# Source code for pystructs.validate

"""Validation system for pystructs."""

from __future__ import annotations

import re
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, List

from pystructs.exceptions import InconsistencyError, ValidationError
from pystructs.expressions import Expression

if TYPE_CHECKING:
    from pystructs.struct import Struct

__all__ = (
    "Validator",
    "FieldValidator",
    "Range",
    "OneOf",
    "Regex",
    "BytePattern",
    "Consistency",
    "Custom",
)


class Validator(ABC):
    """Interface for validators that operate on a whole struct.

    Struct-level validators are invoked when validate() is called on a
    struct. Subclasses must provide :meth:`validate`.
    """

    @abstractmethod
    def validate(self, instance: Struct) -> None:
        """Check the given struct instance, raising on failure.

        The base implementation does nothing and returns ``None``.

        Args:
            instance: The struct instance to validate

        Raises:
            ValidationError: If validation fails
        """


class FieldValidator(ABC):
    """Interface for validators that operate on a single field value.

    Field-level validators are callables: they are run for individual
    field values. Subclasses must provide :meth:`__call__`.
    """

    @abstractmethod
    def __call__(self, value: Any, instance: Struct) -> None:
        """Check one field value, raising on failure.

        The base implementation does nothing and returns ``None``.

        Args:
            value: The field value to validate
            instance: The struct instance (for cross-field validation)

        Raises:
            ValidationError: If validation fails
        """


# === Field-level validators ===


class Range(FieldValidator):
    """Field validator enforcing inclusive lower and/or upper bounds.

    A bound left as ``None`` is not checked.

    Examples:
        >>> version = UInt8(validators=[Range(1, 10)])
    """

    def __init__(self, min_val: float | None = None, max_val: float | None = None):
        """Store the bounds to enforce.

        Args:
            min_val: Minimum allowed value (inclusive)
            max_val: Maximum allowed value (inclusive)
        """
        self.min_val, self.max_val = min_val, max_val

    def __call__(self, value: Any, instance: Struct) -> None:
        """Raise if ``value`` lies outside the configured bounds.

        The lower bound is checked first, then the upper bound.

        Args:
            value: The field value to validate
            instance: The struct instance (unused by this validator)

        Raises:
            ValidationError: If ``value`` is below ``min_val`` or above
                ``max_val``
        """
        below_minimum = self.min_val is not None and value < self.min_val
        if below_minimum:
            raise ValidationError(f"Value {value} is less than minimum {self.min_val}")

        above_maximum = self.max_val is not None and value > self.max_val
        if above_maximum:
            raise ValidationError(
                f"Value {value} is greater than maximum {self.max_val}"
            )

    def __repr__(self) -> str:
        """Return a constructor-like representation showing both bounds."""
        return f"Range({self.min_val}, {self.max_val})"
class OneOf(FieldValidator):
    """Field validator that accepts only values from a fixed set of choices.

    Examples:
        >>> msg_type = UInt8(validators=[OneOf([1, 2, 3])])
    """

    def __init__(self, choices: List[Any]):
        """Store the allowed choices.

        The given object is kept as-is (not copied).

        Args:
            choices: List of allowed values
        """
        self.choices = choices

    def __call__(self, value: Any, instance: Struct) -> None:
        """Raise unless ``value`` is a member of the allowed choices.

        Args:
            value: The field value to validate
            instance: The struct instance (unused by this validator)

        Raises:
            ValidationError: If ``value`` is not in ``choices``
        """
        if value in self.choices:
            return
        raise ValidationError(
            f"Value {value} is not in allowed choices {self.choices}"
        )

    def __repr__(self) -> str:
        """Return a constructor-like representation showing the choices."""
        return f"OneOf({self.choices!r})"
class Regex(FieldValidator):
    """Field validator that checks a string against a regular expression.

    Matching uses ``Pattern.match``, which anchors at the start of the value
    only; end the pattern with ``$`` to constrain the whole value.

    Examples:
        >>> name = String(length=10, validators=[Regex(r'^[a-zA-Z]+$')])
    """

    def __init__(self, pattern: str):
        """Compile and store the pattern.

        Args:
            pattern: Regular expression pattern
        """
        self.pattern = re.compile(pattern)

    def __call__(self, value: Any, instance: Struct) -> None:
        """Raise unless the pattern matches at the start of ``value``.

        Args:
            value: The field value to validate
            instance: The struct instance (unused by this validator)

        Raises:
            ValidationError: If the pattern does not match
        """
        found = self.pattern.match(value)
        if found is None:
            raise ValidationError(
                f"Value '{value}' does not match pattern {self.pattern.pattern}"
            )

    def __repr__(self) -> str:
        """Return a constructor-like representation showing the pattern text."""
        return f"Regex({self.pattern.pattern!r})"
class BytePattern(FieldValidator):
    """Field validator that requires a value to begin with given bytes.

    Examples:
        >>> magic = FixedBytes(4, validators=[BytePattern(b'\\x89PNG')])
    """

    def __init__(self, pattern: bytes):
        """Store the expected prefix.

        Args:
            pattern: Expected byte pattern at start
        """
        self.pattern = pattern

    def __call__(self, value: Any, instance: Struct) -> None:
        """Raise unless ``value`` starts with the expected prefix.

        Args:
            value: The field value to validate
            instance: The struct instance (unused by this validator)

        Raises:
            ValidationError: If ``value`` does not start with ``pattern``
        """
        if value.startswith(self.pattern):
            return
        raise ValidationError(
            f"Bytes do not start with expected pattern {self.pattern!r}"
        )

    def __repr__(self) -> str:
        """Return a constructor-like representation showing the prefix."""
        return f"BytePattern({self.pattern!r})"
# === Struct-level validators ===
class Consistency(Validator):
    """Validate consistency between fields.

    Compares a field's value against one or more expressions, each evaluated
    against the struct instance being validated.

    Examples:
        >>> class Packet(Struct):
        ...     class Meta:
        ...         validators = [
        ...             Consistency('payload_size', equals=Len('payload')),
        ...             Consistency('checksum', equals=Checksum('payload', 'crc32')),
        ...         ]
    """

    def __init__(
        self,
        field: str,
        equals: Expression | None = None,
        greater_than: Expression | None = None,
        less_than: Expression | None = None,
    ):
        """Initialize a Consistency validator.

        Any combination of the three constraints may be given; each one that
        is not ``None`` is checked by :meth:`validate`.

        Args:
            field: Field name to validate
            equals: Expected value (as Expression)
            greater_than: Value must be greater than this (as Expression)
            less_than: Value must be less than this (as Expression)
        """
        self.field = field
        self.equals = equals
        self.greater_than = greater_than
        self.less_than = less_than

    def validate(self, instance: Struct) -> None:
        """Check the configured field of ``instance`` against each constraint.

        Constraints are checked in the order equals, greater_than, less_than;
        the first failing one raises.

        Args:
            instance: The struct instance to validate

        Raises:
            InconsistencyError: If the field differs from the ``equals`` value
            ValidationError: If the field is not strictly greater than the
                ``greater_than`` value, or not strictly less than the
                ``less_than`` value
        """
        actual = getattr(instance, self.field)

        if self.equals is not None:
            expected = self.equals.evaluate(instance)
            if actual != expected:
                raise InconsistencyError(
                    field=self.field,
                    actual=actual,
                    expected=expected,
                )

        if self.greater_than is not None:
            threshold = self.greater_than.evaluate(instance)
            if not (actual > threshold):
                raise ValidationError(
                    f"Field '{self.field}': {actual} is not > {threshold}"
                )

        if self.less_than is not None:
            threshold = self.less_than.evaluate(instance)
            if not (actual < threshold):
                raise ValidationError(
                    f"Field '{self.field}': {actual} is not < {threshold}"
                )

    def __repr__(self) -> str:
        """Return a constructor-like representation listing set constraints."""
        parts = [f"field={self.field!r}"]
        # Use the same "is not None" test as validate(): a truthiness check
        # would hide a configured expression that happens to be falsy.
        for name in ("equals", "greater_than", "less_than"):
            expression = getattr(self, name)
            if expression is not None:
                parts.append(f"{name}={expression!r}")
        return f"Consistency({', '.join(parts)})"
class Custom(Validator):
    """Struct-level validator backed by a user-supplied predicate.

    Examples:
        >>> class Packet(Struct):
        ...     class Meta:
        ...         validators = [
        ...             Custom(lambda self: self.version in (1, 2, 3),
        ...                    "Unsupported version"),
        ...         ]
    """

    def __init__(
        self,
        func: Callable[[Struct], bool],
        message: str = "Custom validation failed",
    ):
        """Store the predicate and its failure message.

        Args:
            func: Validation function that returns True if valid
            message: Error message if validation fails
        """
        self.func, self.message = func, message

    def validate(self, instance: Struct) -> None:
        """Call the predicate on ``instance`` and raise if it returns a falsy value.

        Args:
            instance: The struct instance to validate

        Raises:
            ValidationError: With the configured message, if the predicate
                returns a falsy value
        """
        passed = self.func(instance)
        if not passed:
            raise ValidationError(self.message)

    def __repr__(self) -> str:
        """Return a representation showing only the failure message."""
        return f"Custom({self.message!r})"