mk_assert.mk_assert

  1from typing import Callable, Any, Optional
  2
  3import logging
  4
  5from . import print_helper
  6
  7
  8# Global test state -------------------------------------------------------------------#
  9g_tests: list["TestFn"] = []
 10g_active_test: Optional["ActiveTestContext"] = None
 11g_setup_fn: Optional[Callable[[], None]] = None
 12g_teardown_fn: Optional[Callable[[], None]] = None
 13
 14
 15class TestFn:
 16    def __init__(
 17        self, func: Callable[..., None], args: tuple[Any, ...], kwargs: dict[str, Any]
 18    ):
 19        """
 20        :param func: The test function to be called
 21        :param args: Positional arguments to pass to the test function
 22        :param kwargs: Keyword arguments to pass to the test function
 23        """
 24        self.func: Callable[..., None] = func
 25        self.args: tuple[Any, ...] = args
 26        self.kwargs: dict[str, Any] = kwargs
 27
 28    def run(self) -> None:
 29        """
 30        Run the test function with the stored arguments.
 31        """
 32        return self.func(*self.args, **self.kwargs)
 33
 34
 35class ActiveTestContext:
 36    def __init__(self, test_fn: TestFn):
 37        """
 38        :param test_fn: The TestFn instance representing the active test.
 39        """
 40        self._test_fn = test_fn
 41        self.passed: int = 0
 42        self.failed: int = 0
 43
 44    def __enter__(self):
 45        global g_active_test
 46        g_active_test = self
 47        return self
 48
 49    def __exit__(self, exc_type, exc_value, _traceback):
 50        global g_active_test
 51        g_active_test = None
 52
 53        if exc_type is not None:
 54            self.failure()
 55            logging.critical(
 56                f"Test '{self._test_fn.func.__name__}' raised an exception: {exc_value}"
 57            )
 58
 59        return False
 60
 61    def success(self):
 62        """
 63        Record a successful assertion.
 64        """
 65        self.passed += 1
 66
 67    def failure(self):
 68        """
 69        Record a failed assertion.
 70        """
 71        self.failed += 1
 72
 73
 74def set_setup_fn(setup_fn: Callable[[], None]) -> None:
 75    """
 76    Set a global setup function to be called before each test.
 77
 78    :param setup_fn: The setup function to set.
 79    """
 80    global g_setup_fn
 81    g_setup_fn = setup_fn
 82
 83
 84def set_teardown_fn(teardown_fn: Callable[[], None]) -> None:
 85    """
 86    Set a global teardown function to be called after each test.
 87
 88    :param teardown_fn: The teardown function to set.
 89    """
 90    global g_teardown_fn
 91    g_teardown_fn = teardown_fn
 92
 93
 94def add_test(func, *args, run_now: bool = False, **kwargs):
 95    """
 96    Register a test function to be run later or immediately.
 97
 98    :param func: The test function to register.
 99    :param args: Positional arguments to pass to the test function.
100    :param run_now: If True, run the test immediately instead of registering it.
101    :param kwargs: Keyword arguments to pass to the test function.
102    """
103
104    global g_tests
105
106    logging.debug(f"Adding test: {func.__name__}, run_now={run_now}")
107    test_fn = TestFn(func, args, kwargs)
108    if run_now:
109        _run_single_test(test_fn)
110    else:
111        g_tests.append(test_fn)
112
113
def _run_single_test(test_fn: TestFn) -> None:
    """
    Run a single test function within an active test context.

    The global setup hook (if any) runs first, then the test, then the summary is
    printed. The global teardown hook (if any) runs once setup has completed, even
    when the test body raises, so whatever setup acquired is always released.

    :param test_fn: The TestFn instance representing the test to run.
    :raises Exception: Any exception raised by the setup hook, the test or the
        teardown hook propagates to the caller after the context has recorded it.
    """
    with ActiveTestContext(test_fn) as active_test:
        test_name = test_fn.func.__name__

        if g_setup_fn is not None:
            logging.debug(f"Running setup function before test: {test_name}")
            g_setup_fn()

        # Setup succeeded from here on, so teardown must run no matter how the
        # test body ends.
        try:
            logging.debug(f"Running test: {test_name}")

            print_helper.print_test_start(test_name)
            test_fn.run()
            print_helper.print_test_summary(
                test_name,
                active_test.passed,
                active_test.failed,
            )
        finally:
            if g_teardown_fn is not None:
                logging.debug(f"Running teardown function after test: {test_name}")
                g_teardown_fn()
143
144
def run_tests():
    """
    Execute every registered test in registration order.

    A blank line is printed after each test to separate the output of
    consecutive tests.
    """
    logging.debug(f"Running {len(g_tests)} tests...")

    for queued in g_tests:
        _run_single_test(queued)
        print()
156
157
def clear_tests():
    """
    Remove every registered test from the queue.

    The list is emptied in place, so existing references to it stay valid.
    """
    logging.debug("Clearing all registered tests.")
    del g_tests[:]
166
167
def assert_true(cond: bool, msg: str = "", negate: bool = False):
    """
    Assert that a condition is true (or false if negate is True).

    The condition is evaluated by truthiness, so non-bool values such as ``None``,
    ``0``, ``""`` or an empty container count as false and any other value counts
    as true.

    :param cond: The condition to check.
    :param msg: An optional message to display with the assertion result.
    :param negate: If True, assert that the condition is false.
    :raises RuntimeError: If called while no test is active.
    """
    if g_active_test is None:
        raise RuntimeError("No active test context for assertion.")

    # Compare truthiness, not raw values: `None != False` and `2 != True` are both
    # True, which would wrongly record a pass for a falsy or negated-truthy value.
    passed = bool(cond) != bool(negate)

    if passed:
        g_active_test.success()
    else:
        g_active_test.failure()

    print_helper.print_assert(msg, passed)
187
def assert_false(cond: bool, msg: str = "", negate: bool = False):
    """
    Assert that a condition is false (or true if negate is True).

    Implemented by inverting the condition and delegating to assert_true.

    :param cond: The condition to check.
    :param msg: An optional message to display with the assertion result.
    :param negate: If True, assert that the condition is true.
    """
    inverted = not cond
    assert_true(inverted, msg, negate)
197
198
def assert_eqf(a: float, b: float, tol: float, msg: str = "", negate: bool = False):
    """
    Assert that two floating-point numbers differ by at most an absolute tolerance.

    :param a: The first floating-point number.
    :param b: The second floating-point number.
    :param tol: Largest absolute difference at which the numbers still count as equal.
    :param msg: An optional message to display with the assertion result.
    :param negate: If True, assert that the two numbers are not equal within the tolerance.
    """
    within_tolerance = abs(a - b) <= tol
    assert_true(within_tolerance, msg, negate)
g_tests: list[TestFn] = []
g_active_test: Optional[ActiveTestContext] = None
g_setup_fn: Optional[Callable[[], NoneType]] = None
g_teardown_fn: Optional[Callable[[], NoneType]] = None
class TestFn:
16class TestFn:
17    def __init__(
18        self, func: Callable[..., None], args: tuple[Any, ...], kwargs: dict[str, Any]
19    ):
20        """
21        :param func: The test function to be called
22        :param args: Positional arguments to pass to the test function
23        :param kwargs: Keyword arguments to pass to the test function
24        """
25        self.func: Callable[..., None] = func
26        self.args: tuple[Any, ...] = args
27        self.kwargs: dict[str, Any] = kwargs
28
29    def run(self) -> None:
30        """
31        Run the test function with the stored arguments.
32        """
33        return self.func(*self.args, **self.kwargs)
TestFn( func: Callable[..., NoneType], args: tuple[typing.Any, ...], kwargs: dict[str, typing.Any])
17    def __init__(
18        self, func: Callable[..., None], args: tuple[Any, ...], kwargs: dict[str, Any]
19    ):
20        """
21        :param func: The test function to be called
22        :param args: Positional arguments to pass to the test function
23        :param kwargs: Keyword arguments to pass to the test function
24        """
25        self.func: Callable[..., None] = func
26        self.args: tuple[Any, ...] = args
27        self.kwargs: dict[str, Any] = kwargs
Parameters
  • func: The test function to be called
  • args: Positional arguments to pass to the test function
  • kwargs: Keyword arguments to pass to the test function
func: Callable[..., NoneType]
args: tuple[typing.Any, ...]
kwargs: dict[str, typing.Any]
def run(self) -> None:
29    def run(self) -> None:
30        """
31        Run the test function with the stored arguments.
32        """
33        return self.func(*self.args, **self.kwargs)

Run the test function with the stored arguments.

class ActiveTestContext:
36class ActiveTestContext:
37    def __init__(self, test_fn: TestFn):
38        """
39        :param test_fn: The TestFn instance representing the active test.
40        """
41        self._test_fn = test_fn
42        self.passed: int = 0
43        self.failed: int = 0
44
45    def __enter__(self):
46        global g_active_test
47        g_active_test = self
48        return self
49
50    def __exit__(self, exc_type, exc_value, _traceback):
51        global g_active_test
52        g_active_test = None
53
54        if exc_type is not None:
55            self.failure()
56            logging.critical(
57                f"Test '{self._test_fn.func.__name__}' raised an exception: {exc_value}"
58            )
59
60        return False
61
62    def success(self):
63        """
64        Record a successful assertion.
65        """
66        self.passed += 1
67
68    def failure(self):
69        """
70        Record a failed assertion.
71        """
72        self.failed += 1
ActiveTestContext(test_fn: TestFn)
37    def __init__(self, test_fn: TestFn):
38        """
39        :param test_fn: The TestFn instance representing the active test.
40        """
41        self._test_fn = test_fn
42        self.passed: int = 0
43        self.failed: int = 0
Parameters
  • test_fn: The TestFn instance representing the active test.
passed: int
failed: int
def success(self):
62    def success(self):
63        """
64        Record a successful assertion.
65        """
66        self.passed += 1

Record a successful assertion.

def failure(self):
68    def failure(self):
69        """
70        Record a failed assertion.
71        """
72        self.failed += 1

Record a failed assertion.

def set_setup_fn(setup_fn: Callable[[], NoneType]) -> None:
75def set_setup_fn(setup_fn: Callable[[], None]) -> None:
76    """
77    Set a global setup function to be called before each test.
78
79    :param setup_fn: The setup function to set.
80    """
81    global g_setup_fn
82    g_setup_fn = setup_fn

Set a global setup function to be called before each test.

Parameters
  • setup_fn: The setup function to set.
def set_teardown_fn(teardown_fn: Callable[[], NoneType]) -> None:
85def set_teardown_fn(teardown_fn: Callable[[], None]) -> None:
86    """
87    Set a global teardown function to be called after each test.
88
89    :param teardown_fn: The teardown function to set.
90    """
91    global g_teardown_fn
92    g_teardown_fn = teardown_fn

Set a global teardown function to be called after each test.

Parameters
  • teardown_fn: The teardown function to set.
def add_test(func, *args, run_now: bool = False, **kwargs):
 95def add_test(func, *args, run_now: bool = False, **kwargs):
 96    """
 97    Register a test function to be run later or immediately.
 98
 99    :param func: The test function to register.
100    :param args: Positional arguments to pass to the test function.
101    :param run_now: If True, run the test immediately instead of registering it.
102    :param kwargs: Keyword arguments to pass to the test function.
103    """
104
105    global g_tests
106
107    logging.debug(f"Adding test: {func.__name__}, run_now={run_now}")
108    test_fn = TestFn(func, args, kwargs)
109    if run_now:
110        _run_single_test(test_fn)
111    else:
112        g_tests.append(test_fn)

Register a test function to be run later or immediately.

Parameters
  • func: The test function to register.
  • args: Positional arguments to pass to the test function.
  • run_now: If True, run the test immediately instead of registering it.
  • kwargs: Keyword arguments to pass to the test function.
def run_tests():
146def run_tests():
147    """
148    Run all registered tests.
149    """
150
151    global g_tests
152
153    logging.debug(f"Running {len(g_tests)} tests...")
154    for test_fn in g_tests:
155        _run_single_test(test_fn)
156        print()

Run all registered tests.

def clear_tests():
159def clear_tests():
160    """
161    Clear all registered tests.
162    """
163
164    global g_tests
165    logging.debug("Clearing all registered tests.")
166    g_tests.clear()

Clear all registered tests.

def assert_true(cond: bool, msg: str = '', negate: bool = False):
169def assert_true(cond: bool, msg: str = "", negate: bool = False):
170    """
171    Assert that a condition is true (or false if negate is True).
172
173    :param cond: The condition to check.
174    :param msg: An optional message to display with the assertion result.
175    :param negate: If True, assert that the condition is false.
176    """
177
178    global g_active_test
179    if g_active_test is None:
180        raise RuntimeError("No active test context for assertion.")
181
182    if cond != negate:
183        g_active_test.success()
184    else:
185        g_active_test.failure()
186
187    print_helper.print_assert(msg, cond != negate)

Assert that a condition is true (or false if negate is True).

Parameters
  • cond: The condition to check.
  • msg: An optional message to display with the assertion result.
  • negate: If True, assert that the condition is false.
def assert_false(cond: bool, msg: str = '', negate: bool = False):
189def assert_false(cond: bool, msg: str = "", negate: bool = False):
190    """
191    Assert that a condition is false (or true if negate is True).
192
193    :param cond: The condition to check.
194    :param msg: An optional message to display with the assertion result.
195    :param negate: If True, assert that the condition is true.
196    """
197    assert_true(not cond, msg, negate)

Assert that a condition is false (or true if negate is True).

Parameters
  • cond: The condition to check.
  • msg: An optional message to display with the assertion result.
  • negate: If True, assert that the condition is true.
def assert_eqf(a: float, b: float, tol: float, msg: str = '', negate: bool = False):
200def assert_eqf(a: float, b: float, tol: float, msg: str = "", negate: bool = False):
201    """
202    Assert that two floating-point numbers are equal within a tolerance.
203
204    :param a: The first floating-point number.
205    :param b: The second floating-point number.
206    :param tol: The tolerance within which the two numbers are considered equal.
207    :param msg: An optional message to display with the assertion result.
208    :param negate: If True, assert that the two numbers are not equal within the tolerance.
209    """
210    assert_true(abs(a - b) <= tol, msg, negate)

Assert that two floating-point numbers are equal within a tolerance.

Parameters
  • a: The first floating-point number.
  • b: The second floating-point number.
  • tol: The tolerance within which the two numbers are considered equal.
  • msg: An optional message to display with the assertion result.
  • negate: If True, assert that the two numbers are not equal within the tolerance.