Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions benchmarks/bench_collection_serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
#!/usr/bin/env python
# Copyright ScyllaDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Benchmark: collection serialization - BytesIO vs b"".join(parts).

Measures end-to-end serialize_safe performance for List, Map, Tuple, and UserType
with varying collection sizes.
"""

import timeit
import sys
import os

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from cassandra.cqltypes import (
ListType,
SetType,
MapType,
TupleType,
UserType,
Int32Type,
UTF8Type,
FloatType,
)

PROTOCOL_VERSION = 4

# Build parameterized types
ListOfInt = ListType.apply_parameters([Int32Type])
SetOfInt = SetType.apply_parameters([Int32Type])
MapIntToStr = MapType.apply_parameters([Int32Type, UTF8Type])


# For TupleType and UserType, we need to set subtypes on a subclass
class TestTupleType(TupleType):
subtypes = (
Int32Type,
Int32Type,
Int32Type,
Int32Type,
Int32Type,
Int32Type,
Int32Type,
Int32Type,
Int32Type,
Int32Type,
)


class TestUserType(UserType):
subtypes = (Int32Type, UTF8Type, FloatType, Int32Type, UTF8Type)
fieldnames = ("id", "name", "score", "age", "email")
typename = "test_udt"
keyspace = "test_ks"
mapped_class = None
tuple_type = None


def run_bench(label, fn, args, n):
# Warm up
for _ in range(min(1000, n)):
fn(*args)
t = timeit.timeit(lambda: fn(*args), number=n)
ns_per_call = t / n * 1e9
print(f" {label:45s} {t:.3f}s ({ns_per_call:.2f} ns/call)")
return t, ns_per_call


# Test data
list_10 = list(range(10))
list_100 = list(range(100))
list_1000 = list(range(1000))
list_with_nulls = [i if i % 3 != 0 else None for i in range(100)]

map_10 = {i: f"value_{i}" for i in range(10)}
map_100 = {i: f"value_{i}" for i in range(100)}

tuple_10 = tuple(range(10))
udt_val = (1, "test_name", 3.14, 25, "test@example.com")

N_SMALL = 500_000
N_MED = 100_000
N_LARGE = 10_000

print(f"Collection serialization benchmark")
print(f"=" * 70)

results = {}

print(f"\nListType.serialize (list of int32):")
_, r = run_bench(
"10 elements", ListOfInt.serialize, (list_10, PROTOCOL_VERSION), N_SMALL
)
results["list_10"] = r
_, r = run_bench(
"100 elements", ListOfInt.serialize, (list_100, PROTOCOL_VERSION), N_MED
)
results["list_100"] = r
_, r = run_bench(
"1000 elements", ListOfInt.serialize, (list_1000, PROTOCOL_VERSION), N_LARGE
)
results["list_1000"] = r
_, r = run_bench(
"100 elements (33% null)",
ListOfInt.serialize,
(list_with_nulls, PROTOCOL_VERSION),
N_MED,
)
results["list_100_nulls"] = r

print(f"\nMapType.serialize (map<int32, text>):")
_, r = run_bench(
"10 entries", MapIntToStr.serialize, (map_10, PROTOCOL_VERSION), N_SMALL
)
results["map_10"] = r
_, r = run_bench(
"100 entries", MapIntToStr.serialize, (map_100, PROTOCOL_VERSION), N_MED
)
results["map_100"] = r

print(f"\nTupleType.serialize (10 x int32):")
_, r = run_bench(
"10 fields", TestTupleType.serialize, (tuple_10, PROTOCOL_VERSION), N_SMALL
)
results["tuple_10"] = r

print(f"\nUserType.serialize (5 fields: int, text, float, int, text):")
_, r = run_bench(
"5 fields", TestUserType.serialize, (udt_val, PROTOCOL_VERSION), N_SMALL
)
results["udt_5"] = r

# Print summary for easy comparison
print(f"\n{'=' * 70}")
print("Summary (ns/call):")
for k, v in results.items():
print(f" {k:25s}: {v:.2f} ns")
51 changes: 26 additions & 25 deletions cassandra/cqltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@

_number_types = frozenset((int, float))

# Pre-computed null sentinel for collection element serialization (int32 -1)
_INT32_NULL = int32_pack(-1)


def _name_from_hex_string(encoded_name):
bin_str = unhexlify(encoded_name)
Expand Down Expand Up @@ -836,17 +839,16 @@ def serialize_safe(cls, items, protocol_version):
raise TypeError("Received a string for a type that expects a sequence")

subtype, = cls.subtypes
buf = io.BytesIO()
buf.write(int32_pack(len(items)))
inner_proto = max(3, protocol_version)
parts = [int32_pack(len(items))]
for item in items:
if item is None:
buf.write(int32_pack(-1))
parts.append(_INT32_NULL)
else:
itembytes = subtype.to_binary(item, inner_proto)
buf.write(int32_pack(len(itembytes)))
buf.write(itembytes)
return buf.getvalue()
parts.append(int32_pack(len(itembytes)))
parts.append(itembytes)
return b"".join(parts)


class ListType(_SimpleParameterizedType):
Expand Down Expand Up @@ -899,27 +901,26 @@ def deserialize_safe(cls, byts, protocol_version):
@classmethod
def serialize_safe(cls, themap, protocol_version):
key_type, value_type = cls.subtypes
buf = io.BytesIO()
buf.write(int32_pack(len(themap)))
try:
items = themap.items()
except AttributeError:
raise TypeError("Got a non-map object for a map value")
inner_proto = max(3, protocol_version)
parts = [int32_pack(len(themap))]
for key, val in items:
if key is not None:
keybytes = key_type.to_binary(key, inner_proto)
buf.write(int32_pack(len(keybytes)))
buf.write(keybytes)
parts.append(int32_pack(len(keybytes)))
parts.append(keybytes)
else:
buf.write(int32_pack(-1))
parts.append(_INT32_NULL)
if val is not None:
valbytes = value_type.to_binary(val, inner_proto)
buf.write(int32_pack(len(valbytes)))
buf.write(valbytes)
parts.append(int32_pack(len(valbytes)))
parts.append(valbytes)
else:
buf.write(int32_pack(-1))
return buf.getvalue()
parts.append(_INT32_NULL)
return b"".join(parts)


class TupleType(_ParameterizedType):
Expand Down Expand Up @@ -957,15 +958,15 @@ def serialize_safe(cls, val, protocol_version):
(len(cls.subtypes), len(val), val))

proto_version = max(3, protocol_version)
buf = io.BytesIO()
parts = []
for item, subtype in zip(val, cls.subtypes):
if item is not None:
packed_item = subtype.to_binary(item, proto_version)
buf.write(int32_pack(len(packed_item)))
buf.write(packed_item)
parts.append(int32_pack(len(packed_item)))
parts.append(packed_item)
else:
buf.write(int32_pack(-1))
return buf.getvalue()
parts.append(_INT32_NULL)
return b"".join(parts)

@classmethod
def cql_parameterized_type(cls):
Expand Down Expand Up @@ -1026,7 +1027,7 @@ def deserialize_safe(cls, byts, protocol_version):
@classmethod
def serialize_safe(cls, val, protocol_version):
proto_version = max(3, protocol_version)
buf = io.BytesIO()
parts = []
for i, (fieldname, subtype) in enumerate(zip(cls.fieldnames, cls.subtypes)):
# first treat as a tuple, else by custom type
try:
Expand All @@ -1038,11 +1039,11 @@ def serialize_safe(cls, val, protocol_version):

if item is not None:
packed_item = subtype.to_binary(item, proto_version)
buf.write(int32_pack(len(packed_item)))
buf.write(packed_item)
parts.append(int32_pack(len(packed_item)))
parts.append(packed_item)
else:
buf.write(int32_pack(-1))
return buf.getvalue()
parts.append(_INT32_NULL)
return b"".join(parts)

@classmethod
def _make_registered_udt_namedtuple(cls, keyspace, name, field_names):
Expand Down
Loading