"""Struct class and metaclass."""
from __future__ import annotations
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from io import BytesIO
from typing import Any, BinaryIO, Dict, List, Optional, Type, TypeVar
from pystructs.base import BaseField
from pystructs.config import Endian
from pystructs.exceptions import (
FieldValidationError,
SerializationError,
TrailingDataError,
ValidationErrors,
)
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = (
    "StructMeta",
    "StructOptions",
    "Struct",
)
# Type variable bound to Struct. ``parse`` and ``_parse_stream`` annotate
# ``cls: Type[T]`` and return ``T``, so calling them on a subclass is typed as
# returning that subclass rather than plain Struct.
T = TypeVar("T", bound="Struct")
@dataclass
class StructOptions:
    """Per-class configuration for a Struct.

    Instances hold the options declared on a Struct's inner ``Meta`` class.
    """

    # Byte order used for the struct's fields.
    endian: str = Endian.LITTLE
    # What Struct.parse does with unconsumed bytes: 'error', 'warn' or 'ignore'.
    trailing_data: str = "error"
    # Rule objects used by Struct.sync (each exposes ``target`` and ``apply``).
    sync_rules: List[Any] = field(default_factory=list)
    # Struct-level validators used by Struct.validate (each exposes ``validate``).
    validators: List[Any] = field(default_factory=list)

    def inherit_from(self, parent: StructOptions) -> None:
        """Copy every option from a parent Struct's options into this one.

        Args:
            parent: Options of the parent Struct class.
        """
        # Scalar options are assigned directly.
        self.endian, self.trailing_data = parent.endian, parent.trailing_data
        # The rule lists are shallow-copied, so appending to the child's
        # lists leaves the parent's lists untouched.
        self.sync_rules = parent.sync_rules.copy()
        self.validators = parent.validators.copy()
class Struct(metaclass=StructMeta):
    """Base class for binary structures.

    Define fields as class attributes and use parse() to create instances
    from binary data, or to_bytes() to serialize instances.

    Example:
        >>> class Packet(Struct):
        ...     class Meta:
        ...         endian = 'big'
        ...     magic = UInt32(default=0xDEADBEEF)
        ...     version = UInt8(default=1)
        ...     length = UInt16()
        ...
        >>> packet = Packet.parse(raw_bytes)
        >>> packet.version
        1
        >>> packet.to_bytes()
        b'...'
    """

    # Ordered mapping of field name -> field object, and the class's options.
    # Presumably populated per subclass by StructMeta (not visible here).
    _fields: OrderedDict[str, BaseField]
    _meta: StructOptions

    def __init__(self, _raw: bytes | None = None, **kwargs: Any):
        """Initialize a struct instance.

        Fields not given in ``kwargs`` take the field's default (called first
        if it is callable). Optional fields without a default are set to None.
        Required fields without a default are left unset; to_bytes() raises
        SerializationError if they are still missing at that point.

        Args:
            _raw: Original raw bytes (internal use for parsing)
            **kwargs: Field values to initialize. Keys that are not fields of
                this struct are ignored.
        """
        self._data: Dict[str, Any] = {}
        self._raw = _raw
        self._parent: Optional[Struct] = None
        self._root: Struct = self

        for name, field_obj in self._fields.items():
            if name in kwargs:
                self._data[name] = kwargs[name]
            elif field_obj.default is not None:
                default = field_obj.default
                self._data[name] = default() if callable(default) else default
            elif not field_obj.required:
                self._data[name] = None
            # Required and no default: the value must be set before to_bytes().

    # === Class methods ===

    @classmethod
    def parse(
        cls: Type[T],
        data: bytes,
        allow_trailing: bool = False,
    ) -> T:
        """Parse binary data into a struct instance.

        Args:
            data: Binary data to parse
            allow_trailing: If True, ignore trailing bytes after parsing
                regardless of ``Meta.trailing_data``.

        Returns:
            Parsed struct instance, with ``_raw`` set to ``data``.

        Raises:
            TrailingDataError: If trailing data exists, ``allow_trailing`` is
                False and the class's ``trailing_data`` policy is 'error'.
            Exceptions raised by the field parsers (e.g. ParseError) propagate
            unchanged.
        """
        stream = BytesIO(data)
        instance = cls._parse_stream(stream)
        instance._raw = data

        # Whatever the fields did not consume is trailing data.
        remaining = stream.read()
        if remaining and not allow_trailing:
            policy = cls._meta.trailing_data
            if policy == "error":
                raise TrailingDataError(len(remaining))
            if policy == "warn":
                # stacklevel=2 attributes the warning to the caller of parse().
                warnings.warn(
                    f"Ignoring {len(remaining)} trailing bytes",
                    stacklevel=2,
                )
            # Any other policy (e.g. 'ignore') accepts the bytes silently.
        return instance

    @classmethod
    def _parse_stream(
        cls: Type[T],
        stream: BinaryIO,
        parent: Struct | None = None,
    ) -> T:
        """Parse from a stream (internal use).

        The instance is created without calling __init__, so no defaults are
        applied; every field is read from the stream.

        Args:
            stream: Binary stream to read from
            parent: Parent struct for nested structs

        Returns:
            Parsed struct instance
        """
        instance = cls.__new__(cls)
        instance._data = {}
        instance._raw = None
        instance._parent = parent
        # Explicit None check: a parent struct must count as a parent even if
        # a subclass makes it falsy (e.g. via __len__ or __bool__).
        instance._root = parent._root if parent is not None else instance

        for name, field_obj in cls._fields.items():
            # Each value is stored before the next field is parsed, and the
            # instance is handed to the field, so a field's parser can see the
            # values parsed before it.
            instance._data[name] = field_obj.parse(stream, instance)
        return instance

    @classmethod
    def get_fixed_size(cls) -> Optional[int]:
        """Get the fixed size of this struct if all fields are fixed-size.

        Returns:
            Total size in bytes (0 for a struct with no fields), or None if
            any field is not a FixedField.
        """
        from pystructs.base import FixedField

        total = 0
        for field_obj in cls._fields.values():
            if not isinstance(field_obj, FixedField):
                return None
            total += field_obj.size
        return total

    # === Instance methods ===

    def to_bytes(
        self,
        sync: bool = False,
        validate: bool = False,
    ) -> bytes:
        """Serialize the struct to bytes.

        Fields are serialized in declaration order and concatenated.

        Args:
            sync: If True, run sync() before serializing
            validate: If True, run validate() before serializing

        Returns:
            Serialized bytes

        Raises:
            SerializationError: If a required field has no value (None or
                never set), or if a field's serializer raises it.
            ValidationErrors: If validation fails (when validate=True)
        """
        if sync:
            self.sync()
        if validate:
            self.validate()

        buffer = BytesIO()
        for name, field_obj in self._fields.items():
            value = self._data.get(name)
            if value is None and field_obj.required:
                raise SerializationError(field=name, reason="Missing required field")
            buffer.write(field_obj.serialize(value, self))
        return buffer.getvalue()

    def sync(self, fields: List[str] | None = None) -> Struct:
        """Run synchronization rules to update field values.

        Rules come from ``self._meta.sync_rules`` and are applied in order.

        Args:
            fields: If given, only rules whose ``target`` is in this list are
                applied (default: all rules).

        Returns:
            self (for method chaining)
        """
        for rule in self._meta.sync_rules:
            if fields is None or rule.target in fields:
                rule.apply(self)
        return self

    def validate(self) -> Struct:
        """Run all validation rules.

        Field validators are called as ``validator(value, self)``. Struct-level
        validators from ``self._meta.validators`` are called as
        ``validator.validate(self)``. Every validator runs even if an earlier
        one fails, so all problems are reported together.

        Returns:
            self (for method chaining)

        Raises:
            ValidationErrors: If any validation fails. It wraps every collected
                error; field-level errors are wrapped in FieldValidationError.
        """
        errors: List[Exception] = []

        # Field-level validation. Any exception counts as a failure and is
        # collected instead of aborting the remaining checks.
        for name, field_obj in self._fields.items():
            value = self._data.get(name)
            for validator in field_obj.validators:
                try:
                    validator(value, self)
                except Exception as e:
                    errors.append(FieldValidationError(name, e))

        # Struct-level validation.
        for validator in self._meta.validators:
            try:
                validator.validate(self)
            except Exception as e:
                errors.append(e)

        if errors:
            raise ValidationErrors(errors)
        return self

    def get_size(self) -> int:
        """Calculate the current serialization size.

        Returns:
            Sum of each field's ``get_size(self)`` in bytes.
        """
        return sum(field_obj.get_size(self) for field_obj in self._fields.values())

    # === Attribute access ===

    def __getattr__(self, name: str) -> Any:
        """Get field value by attribute access.

        Python calls this only when normal attribute lookup fails.

        Args:
            name: Attribute name

        Returns:
            Field value (None if the field has not been set)

        Raises:
            AttributeError: If not a valid field. Underscore-prefixed names
                always raise; this also prevents infinite recursion when
                ``_data`` itself is not set yet.
        """
        if name.startswith("_"):
            raise AttributeError(name)
        if name in self.__class__._fields:
            return self._data.get(name)
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        """Set field value by attribute access.

        Public names that are fields are stored in ``_data``. Everything else,
        including all underscore-prefixed names, is set as a normal attribute.

        Args:
            name: Attribute name
            value: Value to set
        """
        if not name.startswith("_") and name in self.__class__._fields:
            self._data[name] = value
        else:
            super().__setattr__(name, value)

    # === Utilities ===

    def __repr__(self) -> str:
        fields_str = ", ".join(
            f"{name}={getattr(self, name)!r}" for name in self._fields
        )
        return f"{self.__class__.__name__}({fields_str})"

    def __eq__(self, other: Any) -> bool:
        """Compare field values with another struct.

        Objects that are not instances of this struct's class get
        NotImplemented rather than False, so Python tries the reflected
        comparison. This keeps ``parent == child`` and ``child == parent``
        consistent. Comparing with an unrelated object still ends up False
        unless that object's own __eq__ says otherwise.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self._data == other._data

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary.

        Nested structs are converted with their own to_dict(), including
        structs found inside lists at any nesting depth. Lists are copied;
        all other values are returned as-is.

        Returns:
            Dictionary with field names as keys
        """

        def convert(value: Any) -> Any:
            # Recursively turn structs (and lists containing them) into
            # plain dicts/lists.
            if isinstance(value, Struct):
                return value.to_dict()
            if isinstance(value, list):
                return [convert(item) for item in value]
            return value

        return {name: convert(self._data.get(name)) for name in self._fields}