diff --git a/cassandra/policies.py b/cassandra/policies.py index e742708019..26658a2822 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -502,19 +502,29 @@ def make_query_plan(self, working_keyspace=None, query=None): yield host return + is_lwt = query.is_lwt() + replicas = [] tablet = self._cluster_metadata._tablets.get_tablet_for_key( keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key)) if tablet is not None: - replicas_mapped = set(map(lambda r: r[0], tablet.replicas)) - child_plan = child.make_query_plan(keyspace, query) + if is_lwt: + # For LWT queries, preserve the tablet's natural replica order + # so that the first replica (Paxos leader) is tried first. + # Using the child policy's round-robin order would lose this. + replicas = [self._cluster_metadata.get_host_by_host_id(host_id) + for host_id, _shard in tablet.replicas] + replicas = [r for r in replicas if r is not None] + else: + replicas_mapped = set(map(lambda r: r[0], tablet.replicas)) + child_plan = child.make_query_plan(keyspace, query) - replicas = [host for host in child_plan if host.host_id in replicas_mapped] + replicas = [host for host in child_plan if host.host_id in replicas_mapped] else: replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key) - if self.shuffle_replicas and not query.is_lwt(): + if self.shuffle_replicas and not is_lwt: shuffle(replicas) def yield_in_order(hosts): @@ -523,10 +533,26 @@ def yield_in_order(hosts): if replica.is_up and child.distance(replica) == distance: yield replica - # yield replicas: local_rack, local, remote - yield from yield_in_order(replicas) - # yield rest of the cluster: local_rack, local, remote - yield from yield_in_order([host for host in child.make_query_plan(keyspace, query) if host not in replicas]) + if is_lwt: + # For LWT queries, yield replicas in their natural token-ring order + # (first replica = Paxos leader). Do NOT re-sort by distance, as + # that could demote the Paxos leader when using RackAwareRoundRobinPolicy + # (the leader might be in a different rack than the client). + # Only skip hosts that are down or IGNORED by the child policy. + replicas_yielded = set() + for replica in replicas: + if replica.is_up and child.distance(replica) != HostDistance.IGNORED: + replicas_yielded.add(replica) + yield replica + # Yield remaining hosts (non-replicas) in distance order as fallback + yield from yield_in_order( + [host for host in child.make_query_plan(keyspace, query) + if host not in replicas_yielded]) + else: + # yield replicas: local_rack, local, remote + yield from yield_in_order(replicas) + # yield rest of the cluster: local_rack, local, remote + yield from yield_in_order([host for host in child.make_query_plan(keyspace, query) if host not in replicas]) def on_up(self, *args, **kwargs): return self._child_policy.on_up(*args, **kwargs) diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index 6142af1aa1..656f4db92b 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -944,6 +944,425 @@ def _assert_shuffle(self, patched_shuffle, cluster, keyspace, routing_key): assert patched_shuffle.call_count == 1 +class LWTTokenAwareRoutingTest(unittest.TestCase): + """ + Tests for LWT-aware routing in TokenAwarePolicy. + + These tests verify that LWT queries yield replicas in their natural + token-ring order (first replica = Paxos leader) rather than re-sorting + by distance. This is critical because the Paxos leader might be in a + different rack than the client, and distance-based reordering would + demote it, causing an extra network hop for every LWT operation. + + See: https://github.com/scylladb/python-driver/issues/780 + https://github.com/scylladb/python-driver/issues/781 + """ + + @staticmethod + def _make_lwt_statement(routing_key, keyspace="ks", table=None): + """Create a Statement that reports is_lwt()=True.""" + stmt = Statement(routing_key=routing_key, keyspace=keyspace) + stmt.is_lwt = lambda: True + if table: + stmt.table = table + return stmt + + def _make_cluster_and_hosts(self, num_hosts=6): + """Create a mock cluster with hosts spread across DCs and racks.""" + hosts = [ + Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) + for i in range(num_hosts) + ] + for host in hosts: + host.set_up() + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + cluster.metadata._tablets = Mock(spec=Tablets) + cluster.metadata._tablets.get_tablet_for_key.return_value = None + return cluster, hosts + + def test_lwt_rack_aware_preserves_ring_order(self): + """ + With RackAwareRoundRobinPolicy, LWT queries should yield replicas + in natural token-ring order, even if the Paxos leader (first replica) + is in a different rack than the client. + + Bug: Without the fix, yield_in_order() re-sorts replicas by distance + (LOCAL_RACK before LOCAL), which can demote the Paxos leader. + """ + cluster, hosts = self._make_cluster_and_hosts(6) + + # Paxos leader (hosts[0]) is in rack2, client is in rack1 + hosts[0].set_location_info("dc1", "rack2") # Paxos leader, different rack + hosts[1].set_location_info("dc1", "rack1") # same rack as client + hosts[2].set_location_info("dc1", "rack2") # different rack + hosts[3].set_location_info("dc1", "rack1") # non-replica, same rack + hosts[4].set_location_info("dc1", "rack2") # non-replica, different rack + hosts[5].set_location_info("dc2", "rack1") # remote DC + + # Ring order: hosts[0] (rack2), hosts[1] (rack1), hosts[2] (rack2) + def get_replicas(keyspace, packed_key): + return [hosts[0], hosts[1], hosts[2]] + + cluster.metadata.get_replicas.side_effect = get_replicas + + policy = TokenAwarePolicy( + RackAwareRoundRobinPolicy("dc1", "rack1", used_hosts_per_remote_dc=0) + ) + policy.populate(cluster, hosts) + + lwt_query = self._make_lwt_statement(routing_key=b"key1") + qplan = list(policy.make_query_plan(None, lwt_query)) + + # LWT: replicas should be in ring order, NOT distance order + # hosts[0] (rack2/Paxos leader) MUST come first, not be demoted + assert qplan[0] == hosts[0], ( + "Paxos leader (rack2) should be first, got %s" % qplan[0].endpoint + ) + assert qplan[1] == hosts[1] + assert qplan[2] == hosts[2] + + def test_lwt_dc_aware_preserves_ring_order(self): + """ + With DCAwareRoundRobinPolicy, LWT queries should yield replicas + in natural token-ring order. (DCAware doesn't distinguish racks, + so ring order is already preserved in practice, but this test + ensures the LWT path is active.) + """ + cluster, hosts = self._make_cluster_and_hosts(4) + + for h in hosts[:3]: + h.set_location_info("dc1", "rack1") + # Only one remote host so it's guaranteed to get REMOTE distance + hosts[3].set_location_info("dc2", "rack1") + + # Ring order: hosts[0], hosts[1], hosts[2] are local replicas + # hosts[3] is a remote replica + def get_replicas(keyspace, packed_key): + return [hosts[0], hosts[1], hosts[2], hosts[3]] + + cluster.metadata.get_replicas.side_effect = get_replicas + + policy = TokenAwarePolicy( + DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1) + ) + policy.populate(cluster, hosts) + + lwt_query = self._make_lwt_statement(routing_key=b"key1") + qplan = list(policy.make_query_plan(None, lwt_query)) + + # First 3 should be local replicas in ring order + assert qplan[0] == hosts[0] + assert qplan[1] == hosts[1] + assert qplan[2] == hosts[2] + # Then remote replica (not IGNORED since it's the only dc2 host) + assert qplan[3] == hosts[3] + + def test_non_lwt_rack_aware_uses_distance_order(self): + """ + Non-LWT queries with RackAwareRoundRobinPolicy should still sort + replicas by distance (LOCAL_RACK first, then LOCAL, then REMOTE). + This verifies the fix doesn't break non-LWT behavior. + """ + cluster, hosts = self._make_cluster_and_hosts(4) + + hosts[0].set_location_info("dc1", "rack2") # different rack + hosts[1].set_location_info("dc1", "rack1") # same rack as client + hosts[2].set_location_info("dc1", "rack2") # different rack + hosts[3].set_location_info("dc1", "rack1") # same rack + + # Ring order: hosts[0] first, hosts[1] second + def get_replicas(keyspace, packed_key): + return [hosts[0], hosts[1], hosts[2], hosts[3]] + + cluster.metadata.get_replicas.side_effect = get_replicas + + policy = TokenAwarePolicy( + RackAwareRoundRobinPolicy("dc1", "rack1", used_hosts_per_remote_dc=0) + ) + policy.populate(cluster, hosts) + + # Non-LWT query: should use distance ordering + non_lwt_query = Statement(routing_key=b"key1", keyspace="ks") + non_lwt_query.is_lwt = lambda: False + qplan = list(policy.make_query_plan(None, non_lwt_query)) + + # LOCAL_RACK hosts (rack1) should come before LOCAL hosts (rack2) + rack1_hosts = {hosts[1], hosts[3]} + rack2_hosts = {hosts[0], hosts[2]} + assert set(qplan[:2]) == rack1_hosts, "rack1 hosts should be first for non-LWT" + assert set(qplan[2:]) == rack2_hosts, "rack2 hosts should come after rack1" + + def test_lwt_skips_down_hosts(self): + """LWT routing should skip hosts that are down.""" + cluster, hosts = self._make_cluster_and_hosts(4) + + for h in hosts: + h.set_location_info("dc1", "rack1") + + # Mark hosts[0] (Paxos leader) as down + hosts[0].set_up() + hosts[0].is_up = False + + def get_replicas(keyspace, packed_key): + return [hosts[0], hosts[1], hosts[2]] + + cluster.metadata.get_replicas.side_effect = get_replicas + + policy = TokenAwarePolicy( + DCAwareRoundRobinPolicy("dc1") + ) + policy.populate(cluster, hosts) + + lwt_query = self._make_lwt_statement(routing_key=b"key1") + qplan = list(policy.make_query_plan(None, lwt_query)) + + # hosts[0] is down, so first should be hosts[1] + assert hosts[0] not in qplan, "Down host should not appear in query plan" + assert qplan[0] == hosts[1], "First up replica should be tried first" + assert qplan[1] == hosts[2] + + def test_lwt_skips_ignored_hosts(self): + """LWT routing should skip hosts with IGNORED distance.""" + cluster, hosts = self._make_cluster_and_hosts(4) + + for h in hosts[:3]: + h.set_location_info("dc1", "rack1") + hosts[3].set_location_info("dc2", "rack1") # remote DC + + # All 4 are replicas, but hosts[3] is in dc2 (IGNORED with default dc-aware) + def get_replicas(keyspace, packed_key): + return [hosts[0], hosts[1], hosts[2], hosts[3]] + + cluster.metadata.get_replicas.side_effect = get_replicas + + policy = TokenAwarePolicy( + DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=0) + ) + policy.populate(cluster, hosts) + + lwt_query = self._make_lwt_statement(routing_key=b"key1") + qplan = list(policy.make_query_plan(None, lwt_query)) + + # hosts[3] is IGNORED (remote DC, used_hosts_per_remote_dc=0) + assert hosts[3] not in qplan, "IGNORED host should not appear in query plan" + # Local replicas in ring order + assert qplan[0] == hosts[0] + assert qplan[1] == hosts[1] + assert qplan[2] == hosts[2] + + def test_lwt_tablet_preserves_natural_order(self): + """ + LWT queries with tablet routing should use tablet's natural + replica order (host_id order in tablet.replicas), not the child + policy's round-robin order. + """ + cluster, hosts = self._make_cluster_and_hosts(4) + + for h in hosts: + h.set_location_info("dc1", "rack1") + + # Set up tablet with replicas in specific order + tablet = Tablet( + first_token=-9223372036854775808, + last_token=9223372036854775807, + replicas=[(hosts[2].host_id, 0), (hosts[0].host_id, 0), (hosts[1].host_id, 0)] + ) + cluster.metadata._tablets.get_tablet_for_key.return_value = tablet + + # Mock get_host_by_host_id to return hosts by their host_id + host_id_map = {h.host_id: h for h in hosts} + cluster.metadata.get_host_by_host_id.side_effect = lambda hid: host_id_map.get(hid) + + policy = TokenAwarePolicy( + DCAwareRoundRobinPolicy("dc1") + ) + policy.populate(cluster, hosts) + + lwt_query = self._make_lwt_statement(routing_key=b"key1", table="test_table") + qplan = list(policy.make_query_plan(None, lwt_query)) + + # Should follow tablet's natural order: hosts[2], hosts[0], hosts[1] + assert qplan[0] == hosts[2], ( + "First tablet replica should be first, got %s" % qplan[0].endpoint + ) + assert qplan[1] == hosts[0] + assert qplan[2] == hosts[1] + + def test_non_lwt_tablet_uses_child_policy_order(self): + """ + Non-LWT queries with tablet routing should use the child policy's + round-robin order (existing behavior), not the tablet's natural order. + """ + cluster, hosts = self._make_cluster_and_hosts(4) + + for h in hosts: + h.set_location_info("dc1", "rack1") + + tablet = Tablet( + first_token=-9223372036854775808, + last_token=9223372036854775807, + replicas=[(hosts[2].host_id, 0), (hosts[0].host_id, 0), (hosts[1].host_id, 0)] + ) + cluster.metadata._tablets.get_tablet_for_key.return_value = tablet + + policy = TokenAwarePolicy( + DCAwareRoundRobinPolicy("dc1") + ) + policy.populate(cluster, hosts) + + non_lwt_query = Statement(routing_key=b"key1", keyspace="ks") + non_lwt_query.is_lwt = lambda: False + non_lwt_query.table = "test_table" + qplan = list(policy.make_query_plan(None, non_lwt_query)) + + # Non-LWT should use child policy's order (round-robin filtered by tablet replicas) + # The important thing is that it does NOT necessarily follow tablet order + replica_ids = {hosts[0].host_id, hosts[1].host_id, hosts[2].host_id} + first_three = set(qplan[:3]) + assert all(h.host_id in replica_ids for h in first_three), "First 3 should be the tablet replicas" + + def test_lwt_does_not_shuffle(self): + """LWT queries should never shuffle replicas, even with shuffle_replicas=True.""" + cluster, hosts = self._make_cluster_and_hosts(4) + + for h in hosts: + h.set_location_info("dc1", "rack1") + + def get_replicas(keyspace, packed_key): + return [hosts[0], hosts[1], hosts[2]] + + cluster.metadata.get_replicas.side_effect = get_replicas + + policy = TokenAwarePolicy( + DCAwareRoundRobinPolicy("dc1"), + shuffle_replicas=True + ) + policy.populate(cluster, hosts) + + lwt_query = self._make_lwt_statement(routing_key=b"key1") + + # Run multiple times - order should always be the same + orders = set() + for _ in range(20): + qplan = list(policy.make_query_plan(None, lwt_query)) + order = tuple(h.endpoint for h in qplan[:3]) + orders.add(order) + + assert len(orders) == 1, ( + "LWT query plan should be deterministic (no shuffle), " + "got %d distinct orders" % len(orders) + ) + + def test_lwt_fallback_hosts_use_distance_order(self): + """ + For LWT queries, non-replica fallback hosts should still use + distance-based ordering (LOCAL_RACK, LOCAL, REMOTE). + """ + cluster, hosts = self._make_cluster_and_hosts(6) + + hosts[0].set_location_info("dc1", "rack1") # replica + hosts[1].set_location_info("dc1", "rack2") # non-replica, LOCAL + hosts[2].set_location_info("dc1", "rack1") # non-replica, LOCAL_RACK + hosts[3].set_location_info("dc1", "rack2") # non-replica, LOCAL + hosts[4].set_location_info("dc1", "rack1") # non-replica, LOCAL_RACK + hosts[5].set_location_info("dc2", "rack1") # non-replica, REMOTE + + # Only hosts[0] is a replica + def get_replicas(keyspace, packed_key): + return [hosts[0]] + + cluster.metadata.get_replicas.side_effect = get_replicas + + policy = TokenAwarePolicy( + RackAwareRoundRobinPolicy("dc1", "rack1", used_hosts_per_remote_dc=0) + ) + policy.populate(cluster, hosts) + + lwt_query = self._make_lwt_statement(routing_key=b"key1") + qplan = list(policy.make_query_plan(None, lwt_query)) + + # First should be the replica (ring order) + assert qplan[0] == hosts[0] + # Remaining hosts should be in distance order: LOCAL_RACK first + fallback = qplan[1:] + rack1_hosts = [h for h in fallback if h.rack == "rack1" and h.datacenter == "dc1"] + rack2_hosts = [h for h in fallback if h.rack == "rack2" and h.datacenter == "dc1"] + if rack1_hosts and rack2_hosts: + # LOCAL_RACK hosts should appear before LOCAL hosts + first_rack1_idx = fallback.index(rack1_hosts[0]) + first_rack2_idx = fallback.index(rack2_hosts[0]) + assert first_rack1_idx < first_rack2_idx, ( + "Fallback hosts should use distance order: LOCAL_RACK before LOCAL" + ) + + def test_lwt_with_all_replicas_down(self): + """LWT with all replicas down should fall back to non-replica hosts.""" + cluster, hosts = self._make_cluster_and_hosts(4) + + for h in hosts: + h.set_location_info("dc1", "rack1") + + # Mark replica hosts as down + hosts[0].is_up = False + hosts[1].is_up = False + + def get_replicas(keyspace, packed_key): + return [hosts[0], hosts[1]] + + cluster.metadata.get_replicas.side_effect = get_replicas + + policy = TokenAwarePolicy( + DCAwareRoundRobinPolicy("dc1") + ) + policy.populate(cluster, hosts) + + lwt_query = self._make_lwt_statement(routing_key=b"key1") + qplan = list(policy.make_query_plan(None, lwt_query)) + + # Should still get fallback hosts + assert len(qplan) > 0, "Should have fallback hosts even if replicas are down" + assert hosts[0] not in qplan + assert hosts[1] not in qplan + # hosts[2] and hosts[3] should be in the fallback + assert hosts[2] in qplan + assert hosts[3] in qplan + + def test_lwt_rack_aware_multiple_rack2_replicas(self): + """ + Regression test: with multiple replicas in a non-client rack, + the LWT path should still preserve ring order for all of them. + """ + cluster, hosts = self._make_cluster_and_hosts(6) + + # All replicas in rack2, none in client's rack1 + hosts[0].set_location_info("dc1", "rack2") + hosts[1].set_location_info("dc1", "rack2") + hosts[2].set_location_info("dc1", "rack2") + hosts[3].set_location_info("dc1", "rack1") # non-replica + hosts[4].set_location_info("dc1", "rack1") # non-replica + hosts[5].set_location_info("dc1", "rack1") # non-replica + + def get_replicas(keyspace, packed_key): + return [hosts[0], hosts[1], hosts[2]] + + cluster.metadata.get_replicas.side_effect = get_replicas + + policy = TokenAwarePolicy( + RackAwareRoundRobinPolicy("dc1", "rack1", used_hosts_per_remote_dc=0) + ) + policy.populate(cluster, hosts) + + lwt_query = self._make_lwt_statement(routing_key=b"key1") + qplan = list(policy.make_query_plan(None, lwt_query)) + + # All rack2 replicas should come in ring order: 0, 1, 2 + assert qplan[0] == hosts[0] + assert qplan[1] == hosts[1] + assert qplan[2] == hosts[2] + + + class ConvictionPolicyTest(unittest.TestCase): def test_not_implemented(self): """