Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pytest
import cpmpy as cp

def pytest_addoption(parser):
"""
Adds cli arguments to the pytest command
"""
parser.addoption(
"--solver", type=str, action="store", default=None, help="Only run the tests on this particular solver."
)

@pytest.fixture
def solver(request):
"""
Limit tests to a specific solver.

By providing the cli argument `--solver=<SOLVER_NAME>`, two things will happen:
- non-solver-specific tests which make a `.solve()` call will now use `SOLVER_NAME` as backend (instead of the default OR-Tools)
- solver-specific tests, like the ones produced through `_generate_inputs`, will be filtered if they don't match `SOLVER_NAME`

By not providing a value for ``--solver`, the default behaviour will be to run non-solver-specific on the default solver (OR-Tools),
and to run all solver-specific tests for which the solver has been installed on the system.
"""
request.cls.solver = request.config.getoption("--solver")
return request.config.getoption("--solver")

def pytest_configure(config):
# Register custom marker for documentation and linting
config.addinivalue_line(
"markers",
"requires_solver(name): mark test as requiring a specific solver", # to filter tests when required solver is not installed
)
config.addinivalue_line(
"markers",
"requires_dependency(name): mark test as requiring a specific dependency", # to filter tests when required solver is not installed
)


def pytest_collection_modifyitems(config, items):
"""
Centrally apply filters and skips to test targets.

For now, only solver-based filtering gets applied.
"""
cmd_solver = config.getoption("--solver") # get cli `--solver`` arg

filtered = []
for item in items:
required_solver_marker = item.get_closest_marker("requires_solver")
required_dependency_marker = item.get_closest_marker("requires_dependency")

# --------------------------------- Dependency filtering --------------------------------- #
if required_dependency_marker:
if not all(importlib.util.find_spec(dependency) is not None for dependency in required_dependency_marker.args):
skip = pytest.mark.skip(reason=f"Dependency {required_dependency_marker.args} not installed")
item.add_marker(skip)
continue

# --------------------------------- Solver filtering --------------------------------- #
if required_solver_marker:
required_solvers = required_solver_marker.args

# --------------------------------- Filtering -------------------------------- #

# when a solver is specified on the command line,
# only run solver-specific tests that require that solver
if cmd_solver:
if cmd_solver in required_solvers:
filtered.append(item)
else:
continue
# instance has survived filtering
else:
filtered.append(item)

# --------------------------------- Skipping --------------------------------- #

# skip test if the required solver is not installed
if not {k:v for k,v in cp.SolverLookup.base_solvers()}[required_solvers[0]].supported():
skip = pytest.mark.skip(reason=f"Solver {cmd_solver} not installed")
item.add_marker(skip)

continue # skip rest of the logic for this test

# ------------------------------ More filtering ------------------------------ #

# Only filter tests that are parameterized with a 'solver' (through `_generate_inputs`)
if hasattr(item, "callspec"):
if cmd_solver:
if "solver" in item.callspec.params:
solver = item.callspec.params["solver"]
if solver == cmd_solver:
filtered.append(item)
if "solver_name" in item.callspec.params:
solver = item.callspec.params["solver_name"]
if solver == cmd_solver:
filtered.append(item)
else:
filtered.append(item)
else:
# keep non-parametrized tests
filtered.append(item)

items[:] = filtered
19 changes: 10 additions & 9 deletions tests/test_builtins.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,51 @@
import unittest
import pytest

import cpmpy as cp
from cpmpy.expressions.python_builtins import all as cpm_all, any as cpm_any
from cpmpy.exceptions import CPMpyException

iv = cp.intvar(-8, 8, shape=5)


@pytest.mark.usefixtures("solver")
class TestBuiltin(unittest.TestCase):

def test_max(self):
constraints = [cp.max(iv) + 9 <= 8]
model = cp.Model(constraints)
self.assertTrue(model.solve())
self.assertTrue(model.solve(solver=self.solver))
self.assertTrue(cp.max(iv.value()) <= -1)

model = cp.Model(cp.max(iv).decompose_comparison('!=', 4))
self.assertTrue(model.solve())
self.assertTrue(model.solve(solver=self.solver))
self.assertNotEqual(str(cp.max(iv.value())), '4')

def test_min(self):
constraints = [cp.min(iv) + 9 == 8]
model = cp.Model(constraints)
self.assertTrue(model.solve())
self.assertTrue(model.solve(solver=self.solver))
self.assertEqual(str(cp.min(iv.value())), '-1')

model = cp.Model(cp.min(iv).decompose_comparison('==', 4))
self.assertTrue(model.solve())
self.assertTrue(model.solve(solver=self.solver))
self.assertEqual(str(cp.min(iv.value())), '4')

def test_abs(self):
constraints = [cp.abs(iv[0]) + 9 <= 8]
model = cp.Model(constraints)
self.assertFalse(model.solve())
self.assertFalse(model.solve(solver=self.solver))

#with list
constraints = [cp.abs(iv+2) <= 8, iv < 0]
model = cp.Model(constraints)
self.assertTrue(model.solve())
self.assertTrue(model.solve(solver=self.solver))

constraints = [cp.abs([iv[0], iv[2], iv[1], -8]) <= 8, iv < 0]
model = cp.Model(constraints)
self.assertTrue(model.solve())
self.assertTrue(model.solve(solver=self.solver))

model = cp.Model(cp.abs(iv[0]).decompose_comparison('!=', 4))
self.assertTrue(model.solve())
self.assertTrue(model.solve(solver=self.solver))
self.assertNotEqual(str(cp.abs(iv[0].value())), '4')

# Boolean builtins
Expand Down
28 changes: 9 additions & 19 deletions tests/test_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from cpmpy import *
from cpmpy.solvers import CPM_gurobi, CPM_pysat, CPM_minizinc, CPM_pysdd, CPM_z3, CPM_exact, CPM_choco, CPM_hexaly


@pytest.mark.requires_solver("ortools")
class TestDirectORTools(unittest.TestCase):

def test_direct_automaton(self):
Expand All @@ -26,8 +26,7 @@ def test_direct_automaton(self):

self.assertEqual(model.solveAll(), 6)


@pytest.mark.skipif(not CPM_exact.supported(), reason="Exact not installed")
@pytest.mark.requires_solver("exact")
class TestDirectExact(unittest.TestCase):

def test_direct_left_reif(self):
Expand All @@ -40,9 +39,7 @@ def test_direct_left_reif(self):
print(model)
self.assertEqual(model.solveAll(), 3)


@pytest.mark.skipif(not CPM_pysat.supported(),
reason="PySAT not installed")
@pytest.mark.requires_solver("pysat")
class TestDirectPySAT(unittest.TestCase):

def test_direct_clause(self):
Expand All @@ -55,8 +52,7 @@ def test_direct_clause(self):
self.assertTrue(model.solve())
self.assertTrue(x.value() or y.value())

@pytest.mark.skipif(not CPM_pysdd.supported(),
reason="PySDD not installed")
@pytest.mark.requires_solver("pysdd")
class TestDirectPySDD(unittest.TestCase):

def test_direct_clause(self):
Expand All @@ -69,8 +65,7 @@ def test_direct_clause(self):
self.assertTrue(model.solve())
self.assertTrue(x.value() or y.value())

@pytest.mark.skipif(not CPM_z3.supported(),
reason="Z3py not installed")
@pytest.mark.requires_solver("z3")
class TestDirectZ3(unittest.TestCase):

def test_direct_clause(self):
Expand All @@ -83,8 +78,7 @@ def test_direct_clause(self):
self.assertTrue(model.solve())
self.assertTrue(AllDifferent(iv).value())

@pytest.mark.skipif(not CPM_minizinc.supported(),
reason="MinZinc not installed")
@pytest.mark.requires_solver("minizinc")
class TestDirectMiniZinc(unittest.TestCase):

def test_direct_clause(self):
Expand All @@ -102,9 +96,7 @@ def test_direct_clause(self):
self.assertTrue(model.solve())
self.assertTrue(AllDifferent(iv).value())


@pytest.mark.skipif(not CPM_gurobi.supported(),
reason="Gurobi not installed")
@pytest.mark.requires_solver("gurobi")
class TestDirectGurobi(unittest.TestCase):

def test_direct_poly(self):
Expand All @@ -127,8 +119,7 @@ def test_direct_poly(self):

self.assertEqual(y.value(), poly_val)

@pytest.mark.skipif(not CPM_choco.supported(),
reason="pychoco not installed")
@pytest.mark.requires_solver("choco")
class TestDirectChoco(unittest.TestCase):

def test_direct_global(self):
Expand All @@ -142,8 +133,7 @@ def test_direct_global(self):
self.assertFalse(model.solve())


@pytest.mark.skipif(not CPM_hexaly.supported(),
reason="hexaly is not installed")
@pytest.mark.requires_solver("hexaly")
class TestDirectHexaly(unittest.TestCase):

def test_direct_distance(self):
Expand Down
Loading
Loading