Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions cassandra/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,120 @@ def on_request_error(self, query, consistency, error, retry_num):
return self.RETHROW, None, None


class LWTRetryPolicy(ExponentialBackoffRetryPolicy):
"""
A retry policy tailored for Lightweight Transaction (LWT) queries.

LWT queries use Paxos consensus, where the first replica in the token ring
acts as the Paxos coordinator (leader). Retrying LWT queries on a *different*
host causes Paxos contention — the new coordinator must compete with the
original one, potentially causing cascading timeouts.

This policy addresses that by:

- **CAS write timeouts**: Retrying on the **same host** (the Paxos coordinator)
with exponential backoff, giving the coordinator time to complete the Paxos round.
- **CAS read timeouts** (serial consistency): Retrying on the same host.
- **Unavailable at serial consistency**: Retrying on the **next host**, since the
Paxos phase failed on this node (not enough replicas alive to form quorum).
- **Non-CAS operations**: Delegating to the standard :class:`ExponentialBackoffRetryPolicy`
behavior.

This is modeled after gocql's ``LWTRetryPolicy`` interface, which retries LWT
queries on the same host to avoid Paxos contention.

Example usage::

from cassandra.cluster import Cluster
from cassandra.policies import LWTRetryPolicy

# Use as the default retry policy for the cluster
cluster = Cluster(
default_retry_policy=LWTRetryPolicy(max_num_retries=3)
)

# Or assign to a specific statement
statement.retry_policy = LWTRetryPolicy(max_num_retries=5)

:param max_num_retries: Maximum number of retry attempts (default: 3).
:param min_interval: Initial backoff delay in seconds (default: 0.1).
:param max_interval: Maximum backoff delay in seconds (default: 10.0).
"""

def __init__(self, max_num_retries=3, min_interval=0.1, max_interval=10.0,
*args, **kwargs):
super(LWTRetryPolicy, self).__init__(
max_num_retries=max_num_retries,
min_interval=min_interval,
max_interval=max_interval,
*args, **kwargs)

def on_write_timeout(self, query, consistency, write_type,
required_responses, received_responses, retry_num):
"""
For CAS (LWT) write timeouts, retry on the **same host** with exponential
backoff. Retrying on a different host would cause Paxos contention.

For non-CAS writes, delegates to the base ExponentialBackoffRetryPolicy
behavior (retry BATCH_LOG only, RETHROW otherwise).
"""
if retry_num >= self.max_num_retries:
return self.RETHROW, None, None

if write_type == WriteType.CAS:
# Retry on the SAME host — this is the Paxos coordinator.
# Moving to another host causes contention in the Paxos protocol.
return self.RETRY, consistency, self._calculate_backoff(retry_num)

# Non-CAS: delegate to parent (retries BATCH_LOG, rethrows others)
return super(LWTRetryPolicy, self).on_write_timeout(
query, consistency, write_type,
required_responses, received_responses, retry_num)

def on_read_timeout(self, query, consistency, required_responses,
received_responses, data_retrieved, retry_num):
"""
For reads at serial consistency (CAS reads), retry on the **same host**
with backoff.

For non-serial reads, delegates to the base ExponentialBackoffRetryPolicy
behavior.
"""
if retry_num >= self.max_num_retries:
return self.RETHROW, None, None

if ConsistencyLevel.is_serial(consistency):
# Serial read = CAS/Paxos read. Retry on same host.
return self.RETRY, consistency, self._calculate_backoff(retry_num)

# Non-serial: delegate to parent
return super(LWTRetryPolicy, self).on_read_timeout(
query, consistency, required_responses,
received_responses, data_retrieved, retry_num)

def on_unavailable(self, query, consistency, required_replicas,
alive_replicas, retry_num):
"""
For serial consistency (CAS/Paxos phase), retry on the **next host** —
this node couldn't form a Paxos quorum, so a different coordinator
might see a different set of available replicas.

For non-serial consistency, delegates to the base ExponentialBackoffRetryPolicy
behavior.
"""
if retry_num >= self.max_num_retries:
return self.RETHROW, None, None

if ConsistencyLevel.is_serial(consistency):
# Paxos phase failed — not enough replicas for serial quorum.
# Try a different coordinator; it might have better connectivity.
return self.RETRY_NEXT_HOST, consistency, self._calculate_backoff(retry_num)

# Non-serial: delegate to parent
return super(LWTRetryPolicy, self).on_unavailable(
query, consistency, required_replicas, alive_replicas, retry_num)


class AddressTranslator(object):
"""
Interface for translating cluster-defined endpoints.
Expand Down
256 changes: 255 additions & 1 deletion tests/unit/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
RetryPolicy, WriteType,
DowngradingConsistencyRetryPolicy, ConstantReconnectionPolicy,
LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy, FallthroughRetryPolicy,
IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy, ExponentialBackoffRetryPolicy)
IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy, ExponentialBackoffRetryPolicy,
LWTRetryPolicy)
from cassandra.connection import DefaultEndPoint, UnixSocketEndPoint
from cassandra.pool import Host
from cassandra.query import Statement
Expand Down Expand Up @@ -1408,6 +1409,259 @@ def test_calculate_backoff(self):
assert d < delay + (0.1 / 2), f"d={d} attempts={attempts}, delay={delay}"


class LWTRetryPolicyTest(unittest.TestCase):
"""Tests for LWTRetryPolicy — LWT-aware retry with same-host preference."""

def _make_policy(self, max_retries=3):
return LWTRetryPolicy(max_num_retries=max_retries)

# --- CAS write timeout: retry on SAME host ---

def test_cas_write_timeout_retries_same_host(self):
"""CAS write timeout on first attempt should retry on SAME host."""
policy = self._make_policy()
retry, consistency, delay = policy.on_write_timeout(
query=None, consistency=ConsistencyLevel.QUORUM,
write_type=WriteType.CAS,
required_responses=3, received_responses=1, retry_num=0)
assert retry == RetryPolicy.RETRY
assert consistency == ConsistencyLevel.QUORUM
assert delay is not None and delay > 0

def test_cas_write_timeout_retries_with_backoff(self):
"""CAS write timeout backoff delay should increase with retry_num."""
policy = self._make_policy(max_retries=5)
delays = []
for attempt in range(3):
_, _, delay = policy.on_write_timeout(
query=None, consistency=ConsistencyLevel.QUORUM,
write_type=WriteType.CAS,
required_responses=3, received_responses=1, retry_num=attempt)
delays.append(delay)
# Delays should generally increase (with some jitter tolerance)
# delay_0 ~ 0.1s, delay_1 ~ 0.2s, delay_2 ~ 0.4s
assert delays[0] < delays[2], (
f"Backoff should increase: delays={delays}")

def test_cas_write_timeout_max_retries_exceeded(self):
"""CAS write timeout should RETHROW when max retries exceeded."""
policy = self._make_policy(max_retries=2)
retry, consistency, delay = policy.on_write_timeout(
query=None, consistency=ConsistencyLevel.QUORUM,
write_type=WriteType.CAS,
required_responses=3, received_responses=1, retry_num=2)
assert retry == RetryPolicy.RETHROW

def test_cas_write_timeout_preserves_consistency(self):
"""CAS retry should preserve the original consistency level."""
policy = self._make_policy()
for cl in [ConsistencyLevel.QUORUM, ConsistencyLevel.LOCAL_QUORUM,
ConsistencyLevel.ONE, ConsistencyLevel.ALL]:
retry, consistency, _ = policy.on_write_timeout(
query=None, consistency=cl,
write_type=WriteType.CAS,
required_responses=3, received_responses=1, retry_num=0)
assert retry == RetryPolicy.RETRY
assert consistency == cl, f"Expected {cl}, got {consistency}"

# --- Non-CAS write timeout: delegate to parent ---

def test_simple_write_timeout_rethrows(self):
"""SIMPLE write timeout should RETHROW (same as base policy)."""
policy = self._make_policy()
retry, consistency, delay = policy.on_write_timeout(
query=None, consistency=ConsistencyLevel.QUORUM,
write_type=WriteType.SIMPLE,
required_responses=3, received_responses=1, retry_num=0)
assert retry == RetryPolicy.RETHROW

def test_batch_log_write_timeout_retries(self):
"""BATCH_LOG write timeout should retry (inherited from base)."""
policy = self._make_policy()
retry, consistency, delay = policy.on_write_timeout(
query=None, consistency=ConsistencyLevel.QUORUM,
write_type=WriteType.BATCH_LOG,
required_responses=3, received_responses=1, retry_num=0)
assert retry == RetryPolicy.RETRY
assert consistency == ConsistencyLevel.QUORUM

def test_counter_write_timeout_rethrows(self):
"""COUNTER write timeout should RETHROW (same as base policy)."""
policy = self._make_policy()
retry, consistency, delay = policy.on_write_timeout(
query=None, consistency=ConsistencyLevel.QUORUM,
write_type=WriteType.COUNTER,
required_responses=3, received_responses=1, retry_num=0)
assert retry == RetryPolicy.RETHROW

# --- Serial (CAS) read timeout: retry on SAME host ---

def test_serial_read_timeout_retries_same_host(self):
"""Read timeout at SERIAL consistency should retry on SAME host."""
policy = self._make_policy()
retry, consistency, delay = policy.on_read_timeout(
query=None, consistency=ConsistencyLevel.SERIAL,
required_responses=3, received_responses=1,
data_retrieved=False, retry_num=0)
assert retry == RetryPolicy.RETRY
assert consistency == ConsistencyLevel.SERIAL
assert delay is not None and delay > 0

def test_local_serial_read_timeout_retries_same_host(self):
"""Read timeout at LOCAL_SERIAL should retry on SAME host."""
policy = self._make_policy()
retry, consistency, delay = policy.on_read_timeout(
query=None, consistency=ConsistencyLevel.LOCAL_SERIAL,
required_responses=3, received_responses=1,
data_retrieved=False, retry_num=0)
assert retry == RetryPolicy.RETRY
assert consistency == ConsistencyLevel.LOCAL_SERIAL
assert delay is not None and delay > 0

def test_serial_read_timeout_max_retries_exceeded(self):
"""Serial read timeout should RETHROW when max retries exceeded."""
policy = self._make_policy(max_retries=1)
retry, consistency, delay = policy.on_read_timeout(
query=None, consistency=ConsistencyLevel.SERIAL,
required_responses=3, received_responses=1,
data_retrieved=False, retry_num=1)
assert retry == RetryPolicy.RETHROW

# --- Non-serial read timeout: delegate to parent ---

def test_non_serial_read_timeout_delegates_to_parent(self):
"""Non-serial read timeout should use base policy behavior."""
policy = self._make_policy()
# Base: retry if enough responses but no data
retry, consistency, delay = policy.on_read_timeout(
query=None, consistency=ConsistencyLevel.QUORUM,
required_responses=2, received_responses=2,
data_retrieved=False, retry_num=0)
assert retry == RetryPolicy.RETRY
assert consistency == ConsistencyLevel.QUORUM

# Base: rethrow if we got data
retry, consistency, delay = policy.on_read_timeout(
query=None, consistency=ConsistencyLevel.QUORUM,
required_responses=2, received_responses=2,
data_retrieved=True, retry_num=0)
assert retry == RetryPolicy.RETHROW

# --- Serial unavailable: retry on NEXT host ---

def test_serial_unavailable_retries_next_host(self):
"""Unavailable at SERIAL should retry on NEXT host."""
policy = self._make_policy()
retry, consistency, delay = policy.on_unavailable(
query=None, consistency=ConsistencyLevel.SERIAL,
required_replicas=3, alive_replicas=1, retry_num=0)
assert retry == RetryPolicy.RETRY_NEXT_HOST
assert consistency == ConsistencyLevel.SERIAL
assert delay is not None and delay > 0

def test_local_serial_unavailable_retries_next_host(self):
"""Unavailable at LOCAL_SERIAL should retry on NEXT host."""
policy = self._make_policy()
retry, consistency, delay = policy.on_unavailable(
query=None, consistency=ConsistencyLevel.LOCAL_SERIAL,
required_replicas=3, alive_replicas=1, retry_num=0)
assert retry == RetryPolicy.RETRY_NEXT_HOST
assert consistency == ConsistencyLevel.LOCAL_SERIAL
assert delay is not None and delay > 0

def test_serial_unavailable_max_retries_exceeded(self):
"""Serial unavailable should RETHROW when max retries exceeded."""
policy = self._make_policy(max_retries=1)
retry, consistency, delay = policy.on_unavailable(
query=None, consistency=ConsistencyLevel.SERIAL,
required_replicas=3, alive_replicas=1, retry_num=1)
assert retry == RetryPolicy.RETHROW

# --- Non-serial unavailable: delegate to parent ---

def test_non_serial_unavailable_delegates_to_parent(self):
"""Non-serial unavailable should use base policy behavior."""
policy = self._make_policy()
# Base: RETRY_NEXT_HOST on first attempt
retry, consistency, delay = policy.on_unavailable(
query=None, consistency=ConsistencyLevel.QUORUM,
required_replicas=3, alive_replicas=1, retry_num=0)
assert retry == RetryPolicy.RETRY_NEXT_HOST

# --- on_request_error: inherited from parent ---

def test_request_error_retries_next_host(self):
"""Request errors should retry on next host (inherited behavior)."""
policy = self._make_policy()
retry, consistency, delay = policy.on_request_error(
query=None, consistency=ConsistencyLevel.QUORUM,
error=Exception("overloaded"), retry_num=0)
assert retry == RetryPolicy.RETRY_NEXT_HOST

def test_request_error_max_retries_exceeded(self):
"""Request errors should RETHROW when max retries exceeded."""
policy = self._make_policy(max_retries=1)
retry, consistency, delay = policy.on_request_error(
query=None, consistency=ConsistencyLevel.QUORUM,
error=Exception("overloaded"), retry_num=1)
assert retry == RetryPolicy.RETHROW

# --- Constructor defaults ---

def test_default_constructor(self):
"""LWTRetryPolicy should have sensible defaults."""
policy = LWTRetryPolicy()
assert policy.max_num_retries == 3
assert policy.min_interval == 0.1
assert policy.max_interval == 10.0

def test_custom_constructor(self):
"""LWTRetryPolicy should accept custom parameters."""
policy = LWTRetryPolicy(max_num_retries=5, min_interval=0.5, max_interval=30.0)
assert policy.max_num_retries == 5
assert policy.min_interval == 0.5
assert policy.max_interval == 30.0

def test_inherits_exponential_backoff(self):
"""LWTRetryPolicy should inherit from ExponentialBackoffRetryPolicy."""
policy = LWTRetryPolicy()
assert isinstance(policy, ExponentialBackoffRetryPolicy)
assert isinstance(policy, RetryPolicy)

# --- Verify 3-tuple return format for all methods ---

def test_all_methods_return_3_tuples(self):
"""All retry decisions should return 3-tuples (decision, cl, delay)."""
policy = self._make_policy()

# CAS write timeout
result = policy.on_write_timeout(
query=None, consistency=ConsistencyLevel.QUORUM,
write_type=WriteType.CAS,
required_responses=3, received_responses=1, retry_num=0)
assert len(result) == 3, f"Expected 3-tuple, got {result}"

# Serial read timeout
result = policy.on_read_timeout(
query=None, consistency=ConsistencyLevel.SERIAL,
required_responses=3, received_responses=1,
data_retrieved=False, retry_num=0)
assert len(result) == 3, f"Expected 3-tuple, got {result}"

# Serial unavailable
result = policy.on_unavailable(
query=None, consistency=ConsistencyLevel.SERIAL,
required_replicas=3, alive_replicas=1, retry_num=0)
assert len(result) == 3, f"Expected 3-tuple, got {result}"

# RETHROW cases
result = policy.on_write_timeout(
query=None, consistency=ConsistencyLevel.QUORUM,
write_type=WriteType.SIMPLE,
required_responses=3, received_responses=1, retry_num=0)
assert len(result) == 3, f"Expected 3-tuple, got {result}"


class WhiteListRoundRobinPolicyTest(unittest.TestCase):

def test_hosts_with_hostname(self):
Expand Down
Loading