diff --git a/src/s2_sdk/_mappers.py b/src/s2_sdk/_mappers.py index 1b93623..e473d57 100644 --- a/src/s2_sdk/_mappers.py +++ b/src/s2_sdk/_mappers.py @@ -340,21 +340,8 @@ def append_ack_from_proto(ack: pb.AppendAck) -> AppendAck: ) -def read_batch_from_proto( - batch: pb.ReadBatch, ignore_command_records: bool = False -) -> ReadBatch: - records = [] - for sr in batch.records: - if ignore_command_records and _is_command_record(sr): - continue - records.append( - SequencedRecord( - seq_num=sr.seq_num, - body=sr.body, - headers=[(h.name, h.value) for h in sr.headers], - timestamp=sr.timestamp, - ) - ) +def read_batch_from_proto(batch: pb.ReadBatch) -> ReadBatch: + records = [sequenced_record_from_proto(sr) for sr in batch.records] tail = None if batch.HasField("tail"): tail = StreamPosition( @@ -373,12 +360,6 @@ def sequenced_record_from_proto(sr: pb.SequencedRecord) -> SequencedRecord: ) -def _is_command_record(sr: pb.SequencedRecord) -> bool: - if len(sr.headers) == 1 and sr.headers[0].name == b"": - return True - return False - - def read_start_params(start: _ReadStart) -> dict[str, Any]: if isinstance(start, SeqNum): return {"seq_num": start.value} diff --git a/src/s2_sdk/_ops.py b/src/s2_sdk/_ops.py index e9759fd..f91fa0a 100644 --- a/src/s2_sdk/_ops.py +++ b/src/s2_sdk/_ops.py @@ -976,7 +976,13 @@ async def read( proto_batch = pb.ReadBatch() proto_batch.ParseFromString(response.content) - return read_batch_from_proto(proto_batch, ignore_command_records) + batch = read_batch_from_proto(proto_batch) + if ignore_command_records: + batch = types.ReadBatch( + records=[r for r in batch.records if not r.is_command_record()], + tail=batch.tail, + ) + return batch @fallible async def read_session( diff --git a/src/s2_sdk/_s2s/_read_session.py b/src/s2_sdk/_s2s/_read_session.py index c7aa629..9d0330b 100644 --- a/src/s2_sdk/_s2s/_read_session.py +++ b/src/s2_sdk/_s2s/_read_session.py @@ -81,7 +81,7 @@ async def run_read_session( proto_batch = pb.ReadBatch() proto_batch.ParseFromString(message_body) - batch = read_batch_from_proto(proto_batch, ignore_command_records) + batch = read_batch_from_proto(proto_batch) if batch.tail is not None: last_tail_at = time.monotonic() @@ -103,7 +103,16 @@ async def run_read_session( ) params["bytes"] = remaining_bytes - yield batch + if ignore_command_records: + batch = ReadBatch( + records=[ + r for r in batch.records if not r.is_command_record() + ], + tail=batch.tail, + ) + + if batch.records: + yield batch return except Exception as e: diff --git a/src/s2_sdk/_types.py b/src/s2_sdk/_types.py index 2b68333..8e1fb8f 100644 --- a/src/s2_sdk/_types.py +++ b/src/s2_sdk/_types.py @@ -242,6 +242,10 @@ class SequencedRecord: timestamp: int """Timestamp for this record.""" + def is_command_record(self) -> bool: + """Check if this is a command record.""" + return len(self.headers) == 1 and self.headers[0][0] == b"" + @dataclass(slots=True) class ReadBatch: