# Source code for pubsub.message

"""Message class for pubsub library."""

import io
import json
import random
import struct
import time
from typing import BinaryIO

# Message serialization format version
MESSAGE_FORMAT_VERSION = 1
SUPPORTED_MESSAGE_VERSIONS = [MESSAGE_FORMAT_VERSION]

# Magic number to identify message format ("PMSG" in ASCII)
MESSAGE_MAGIC_NUMBER = 0x504D5347

# Scalar types allowed as header values
HeaderValueTypes = str | int | float | bool | None
"""Type alias for valid header value types. Headers can contain strings, integers,
floats, booleans, or None as values."""

# Dictionary type for message headers
Header = dict[str, HeaderValueTypes]
"""Type alias for message headers. A dictionary with string keys and scalar values
(str, int, float, bool, or None). Used to store metadata about messages."""

class Message:
    """Represents a message in the pubsub system."""

    # Precompiled struct formats for the wire layout documented in ``write``.
    # "!" selects network (big-endian) byte order with standard sizes and no
    # alignment padding, so "!IBQQ" packs magic, version, id and timestamp into
    # exactly 4 + 1 + 8 + 8 = 21 bytes.
    _FIXED_HEADER = struct.Struct("!IBQQ")
    _UINT8 = struct.Struct("!B")
    _UINT32 = struct.Struct("!I")
    _UINT64 = struct.Struct("!Q")

    def __init__(self, topic: str, content: bytes, headers: Header | None = None):
        """
        Initialize a new message.

        Args:
            topic: The topic this message belongs to
            content: The message payload as bytes
            headers: Optional dictionary mapping string keys to scalar values
                (str, int, float, bool, or None) used as metadata. The dict is
                stored as given, not copied; None becomes an empty dict.
        """
        self.id = self._next_id()
        self.timestamp = int(time.time() * 1_000_000)  # microseconds since epoch
        self.topic = topic
        self.content = content
        self.headers = headers if headers is not None else {}

    @staticmethod
    def _next_id() -> int:
        """
        Generate a message ID from the current time and 16 random bits.

        The current time in microseconds has its least significant 16 bits
        replaced by random bits. IDs therefore sort by creation time only at a
        granularity of 2**16 microseconds (about 65 ms); within one such window
        their order is random. Uniqueness is probabilistic, not guaranteed:
        two messages created in the same window share an ID with probability
        1/65536.

        Returns:
            An integer ID (serialized by ``write`` as a uint64)
        """
        time_micro = int(time.time() * 1_000_000)
        high_bits = (time_micro >> 16) << 16  # clear the low 16 bits
        random_bits = random.randint(0, 0xFFFF)
        return high_bits | random_bits

    def write(self, stream: BinaryIO) -> None:
        """
        Write the message to a binary stream.

        Binary format (all integers big-endian):
            - 4 bytes: magic number (0x504D5347 - "PMSG")
            - 1 byte: format version (uint8)
            - 8 bytes: id (uint64)
            - 8 bytes: timestamp (uint64, microseconds since epoch)
            - 4 bytes: topic length (uint32)
            - N bytes: topic (UTF-8 encoded)
            - 4 bytes: headers JSON length (uint32)
            - N bytes: headers as JSON string (UTF-8 encoded)
            - 4 bytes: content length (uint32)
            - N bytes: content

        The complete frame is assembled in memory and passed to
        ``stream.write`` in a single call. Any encoding or packing failure
        therefore raises before anything reaches the stream, instead of
        leaving a truncated frame behind.

        Args:
            stream: Binary stream to write to

        Raises:
            struct.error: If id or timestamp is outside the uint64 range, or a
                length-prefixed field is longer than 2**32 - 1 bytes.
            TypeError: If the headers contain values ``json`` cannot serialize,
                or content is not bytes-like.
        """
        stream.write(self._encode())

    def _encode(self) -> bytes:
        """Serialize this message into one contiguous frame (layout in ``write``)."""
        topic_bytes = self.topic.encode("utf-8")
        headers_bytes = json.dumps(self.headers, ensure_ascii=False).encode("utf-8")
        return b"".join(
            (
                self._FIXED_HEADER.pack(
                    MESSAGE_MAGIC_NUMBER, MESSAGE_FORMAT_VERSION, self.id, self.timestamp
                ),
                self._UINT32.pack(len(topic_bytes)),
                topic_bytes,
                self._UINT32.pack(len(headers_bytes)),
                headers_bytes,
                self._UINT32.pack(len(self.content)),
                self.content,
            )
        )

    @staticmethod
    def _read_exact(stream: BinaryIO, n: int) -> bytes:
        """
        Read exactly n bytes from stream, handling partial reads gracefully.

        Continues reading until all requested bytes are received or EOF is
        encountered. This handles streams (sockets, pipes) that return data in
        chunks.

        Args:
            stream: Stream to read from
            n: Number of bytes to read

        Returns:
            Exactly n bytes

        Raises:
            ValueError: If the stream returns no data (b"" or None) before all
                bytes have been read
        """
        chunks = []
        bytes_read = 0
        while bytes_read < n:
            chunk = stream.read(n - bytes_read)
            if not chunk:
                # EOF reached - no more data available
                raise ValueError(f"Expected {n} bytes, but only read {bytes_read} bytes (EOF)")
            chunks.append(chunk)
            bytes_read += len(chunk)
        return b"".join(chunks)

    @classmethod
    def _read_uint(cls, stream: BinaryIO, fmt: struct.Struct) -> int:
        """Read one unsigned integer encoded with the single-field format *fmt*."""
        return fmt.unpack(cls._read_exact(stream, fmt.size))[0]

    @classmethod
    def _read_length_prefixed(cls, stream: BinaryIO) -> bytes:
        """Read a uint32 byte count followed by exactly that many bytes."""
        length = cls._read_uint(stream, cls._UINT32)
        return cls._read_exact(stream, length)

    @classmethod
    def read(cls, stream: BinaryIO) -> "Message":
        """
        Read and deserialize a message from a binary stream.

        Fields are consumed one at a time in wire order, so a bad magic number
        or version is reported after reading only that field.

        Args:
            stream: Binary stream to read from

        Returns:
            A new Message instance carrying the id and timestamp from the stream

        Raises:
            ValueError: If the stream ends early, the magic number or format
                version is not recognized, the topic or headers are not valid
                UTF-8, the headers are not valid JSON, or the headers JSON is
                neither an object nor null.
        """
        magic = cls._read_uint(stream, cls._UINT32)
        if magic != MESSAGE_MAGIC_NUMBER:
            raise ValueError(
                f"Invalid magic number 0x{magic:08X}, expected 0x{MESSAGE_MAGIC_NUMBER:08X}. This "
                "data is not a valid message."
            )

        version = cls._read_uint(stream, cls._UINT8)
        if version not in SUPPORTED_MESSAGE_VERSIONS:
            raise ValueError(
                f"Unsupported message format version {version}, expected one of "
                f"{SUPPORTED_MESSAGE_VERSIONS}."
            )

        message_id = cls._read_uint(stream, cls._UINT64)
        message_timestamp = cls._read_uint(stream, cls._UINT64)
        topic = cls._read_length_prefixed(stream).decode("utf-8")

        headers_json = cls._read_length_prefixed(stream).decode("utf-8")
        headers = json.loads(headers_json) if headers_json else {}
        # JSON null is let through: __init__ turns headers=None into {}.
        if headers is not None and not isinstance(headers, dict):
            raise ValueError(
                "Invalid message headers: expected a JSON object, got "
                f"{type(headers).__name__}."
            )

        message_content = cls._read_length_prefixed(stream)

        # Build through __init__, then restore the serialized id and timestamp
        # over the freshly generated ones.
        message = cls(topic=topic, content=message_content, headers=headers)
        message.id = message_id
        message.timestamp = message_timestamp
        return message

    def to_bytes(self) -> bytes:
        """
        Convenience method to serialize message to bytes.

        Returns:
            The serialized message as bytes (the frame produced by ``write``)
        """
        stream = io.BytesIO()
        self.write(stream)
        return stream.getvalue()

    @classmethod
    def from_bytes(cls, data: bytes) -> "Message":
        """
        Convenience method to deserialize message from bytes.

        Only the first message in *data* is decoded; any trailing bytes are
        left unread and ignored.

        Args:
            data: The serialized message bytes

        Returns:
            A new Message instance
        """
        stream = io.BytesIO(data)
        return cls.read(stream)

    def __repr__(self) -> str:
        # !r quotes/escapes the topic unambiguously; type(self) keeps subclass names.
        return (
            f"{type(self).__name__}(id={self.id}, topic={self.topic!r}, "
            f"content_length={len(self.content)})"
        )