|
| 1 | +"""A module with various new assertion functions. |
| 2 | +
|
| 3 | +Index |
| 4 | +----- |
| 5 | +.. currentmodule:: assertionlib.assertion_functions |
| 6 | +.. autosummary:: |
| 7 | + len_eq |
| 8 | + str_eq |
| 9 | + shape_eq |
| 10 | + isdisjoint |
| 11 | + function_eq |
| 12 | +
|
| 13 | +API |
| 14 | +--- |
| 15 | +.. autofunction:: len_eq |
| 16 | +.. autofunction:: str_eq |
| 17 | +.. autofunction:: shape_eq |
| 18 | +.. autofunction:: isdisjoint |
| 19 | +.. autofunction:: function_eq |
| 20 | +
|
| 21 | +""" |
| 22 | + |
| 23 | +import dis |
| 24 | +from types import FunctionType |
| 25 | +from itertools import zip_longest |
| 26 | +from typing import ( |
| 27 | + Sized, |
| 28 | + Callable, |
| 29 | + TypeVar, |
| 30 | + Iterable, |
| 31 | + Union, |
| 32 | + Tuple, |
| 33 | + Hashable, |
| 34 | + TYPE_CHECKING |
| 35 | +) |
| 36 | + |
| 37 | +from .functions import to_positional |
| 38 | + |
| 39 | +if TYPE_CHECKING: |
| 40 | + from numpy import ndarray # type: ignore |
| 41 | +else: |
| 42 | + ndarray = 'numpy.ndarray' |
| 43 | + |
| 44 | +__all__ = ['len_eq', 'str_eq', 'shape_eq', 'isdisjoint', 'function_eq'] |
| 45 | + |
| 46 | +T = TypeVar('T') |
| 47 | +IT = TypeVar('IT', bound=Union[None, dis.Instruction]) |
| 48 | + |
| 49 | + |
| 50 | +@to_positional |
| 51 | +def len_eq(a: Sized, b: int) -> bool: |
| 52 | + """Check if the length of **a** is equivalent to **b**: :code:`len(a) == b`. |
| 53 | +
|
| 54 | + Parameters |
| 55 | + ---------- |
| 56 | + a : :class:`~collections.abc.Sized` |
| 57 | + The object whose size will be evaluated. |
| 58 | +
|
| 59 | + b : :class:`int` |
| 60 | + The integer that will be matched against the size of **a**. |
| 61 | +
|
| 62 | + """ |
| 63 | + return len(a) == b |
| 64 | + |
| 65 | + |
| 66 | +@to_positional |
| 67 | +def str_eq(a: T, b: str, *, str_converter: Callable[[T], str] = repr) -> bool: |
| 68 | + """Check if the string-representation of **a** is equivalent to **b**: :code:`repr(a) == b`. |
| 69 | +
|
| 70 | + Parameters |
| 71 | + ---------- |
| 72 | + a : :class:`T<typing.TypeVar>` |
| 73 | + An object whose string represention will be evaluated. |
| 74 | +
|
| 75 | + b : :class:`str` |
| 76 | + The string that will be matched against the string-output of **a**. |
| 77 | +
|
| 78 | + Keyword Arguments |
| 79 | + ----------------- |
| 80 | + str_converter : :data:`Callable[[T], str]<typing.Callable>` |
| 81 | + The callable for constructing **a**'s string representation. |
| 82 | + Uses :func:`repr` by default. |
| 83 | +
|
| 84 | + """ |
| 85 | + return str_converter(a) == b |
| 86 | + |
| 87 | + |
| 88 | +@to_positional |
| 89 | +def shape_eq(a: ndarray, b: Union[ndarray, Tuple[int, ...]]) -> bool: |
| 90 | + """Check if the shapes of **a** and **b** are equivalent: :code:`a.shape == getattr(b, 'shape', b)`. |
| 91 | +
|
| 92 | + **b** should be either an object with the ``shape`` attribute (*e.g.* a NumPy array) |
| 93 | + or a :class:`tuple` representing a valid array shape. |
| 94 | +
|
| 95 | + Parameters |
| 96 | + ---------- |
| 97 | + a : :class:`numpy.ndarray` |
| 98 | + A NumPy array. |
| 99 | +
|
| 100 | + b : :class:`numpy.ndarray` or :class:`tuple` [:class:`int`, ...] |
| 101 | + A NumPy array or a tuple of integers representing the shape of **a**. |
| 102 | +
|
| 103 | + """ # noqa |
| 104 | + return a.shape == getattr(b, 'shape', b) |
| 105 | + |
| 106 | + |
| 107 | +@to_positional |
| 108 | +def isdisjoint(a: Iterable[Hashable], b: Iterable[Hashable]) -> bool: |
| 109 | + """Check if **a** has no elements in **b**. |
| 110 | +
|
| 111 | + Parameters |
| 112 | + ---------- |
| 113 | + a/b : :class:`~collections.abc.Iterable` [:class:`~collections.abc.Hashable`] |
| 114 | + Two to-be compared iterables. |
| 115 | + Note that both iterables must consist of hashable objects. |
| 116 | +
|
| 117 | + See Also |
| 118 | + -------- |
| 119 | + :meth:`set.isdisjoint()<frozenset.isdisjoint>` |
| 120 | + Return ``True`` if two sets have a null intersection. |
| 121 | +
|
| 122 | + """ |
| 123 | + try: |
| 124 | + return a.isdisjoint(b) # type: ignore |
| 125 | + |
| 126 | + # **a** does not have the isdisjoint method |
| 127 | + except AttributeError: |
| 128 | + return set(a).isdisjoint(b) |
| 129 | + |
| 130 | + # **a.isdisjoint** is not a callable or |
| 131 | + # **a** and/or **b** do not consist of hashable elements |
| 132 | + except TypeError as ex: |
| 133 | + if callable(a.isdisjoint): # type: ignore |
| 134 | + raise ex |
| 135 | + return set(a).isdisjoint(b) |
| 136 | + |
| 137 | + |
| 138 | +@to_positional |
| 139 | +def function_eq(func1: FunctionType, func2: FunctionType) -> bool: |
| 140 | + """Check if two functions are equivalent by checking if their :attr:`__code__` is identical. |
| 141 | +
|
| 142 | + **func1** and **func2** should be instances of :data:`~types.FunctionType` |
| 143 | + or any other object with access to the :attr:`__code__` attribute. |
| 144 | +
|
| 145 | + Parameters |
| 146 | + ---------- |
| 147 | + func1/func2 : :data:`~types.FunctionType` |
| 148 | + Two functions. |
| 149 | +
|
| 150 | + Examples |
| 151 | + -------- |
| 152 | + .. code:: python |
| 153 | +
|
| 154 | + >>> from assertionlib.assertion_functions import function_eq |
| 155 | +
|
| 156 | + >>> func1 = lambda x: x + 5 |
| 157 | + >>> func2 = lambda x: x + 5 |
| 158 | + >>> func3 = lambda x: 5 + x |
| 159 | +
|
| 160 | + >>> print(function_eq(func1, func2)) |
| 161 | + True |
| 162 | +
|
| 163 | + >>> print(function_eq(func1, func3)) |
| 164 | + False |
| 165 | +
|
| 166 | + """ |
| 167 | + code1 = None |
| 168 | + try: |
| 169 | + code1 = func1.__code__ |
| 170 | + code2 = func2.__code__ |
| 171 | + except AttributeError as ex: |
| 172 | + name, obj = ('func1', func1) if code1 is None else ('func2', func2) |
| 173 | + raise TypeError(f"{name!r} expected a function or object with the '__code__' attribute; " |
| 174 | + f"observed type: {obj.__class__.__name__!r}") from ex |
| 175 | + |
| 176 | + iterator = zip_longest(dis.get_instructions(code1), dis.get_instructions(code2)) |
| 177 | + tup_iter = ((_sanitize_instruction(i), _sanitize_instruction(j)) for i, j in iterator) |
| 178 | + return all([i == j for i, j in tup_iter]) |
| 179 | + |
| 180 | + |
| 181 | +def _sanitize_instruction(instruction: IT) -> IT: |
| 182 | + """Sanitize the supplied instruction by setting :attr:`~dis.Instruction.starts_line` to :data:`None`.""" # noqa |
| 183 | + if instruction is None: |
| 184 | + return None |
| 185 | + return instruction._replace(starts_line=None) # type: ignore |
0 commit comments