Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions cpmpy/expressions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,28 @@ def value(self):
elif self.name == ">=": return arg_vals[0] >= arg_vals[1]
return None # default

def get_bounds(self):
(lb1, ub1), (lb2, ub2) = get_bounds(self.args[0]), get_bounds(self.args[1])
if self.name == "==":
if lb1 == ub1 == lb2 == ub2: return (1,1) # equal domains, trivially true
if ub1 < lb2 or ub2 < lb1: return (0,0) # disjoint, trivially false
if self.name == "!=":
if ub1 < lb2 or ub2 < lb1: return (1,1) # disjoint, trivially true
if lb1 == ub1 == lb2 == ub2: return (0,0) # equal domains, trivially false
if self.name == "<=":
if ub1 <= lb2: return (1,1) # domain of lhs is leq domain of rhs
if lb1 > ub2: return (0,0) # domain of lhs is gt domain of rhs
if self.name == "<":
if ub1 < lb2: return (1,1) # domain of lhs is lt domain of rhs
if lb1 >= ub2: return (0,0) # domain of lhs is geq domain of rhs
if self.name == ">=":
if lb1 >= ub2: return (1,1) # domain of lhs is geq domain of rhs
if ub1 < lb2: return (0,0) # domain of lhs is lt domain of rhs
if self.name == ">":
if lb1 > ub2: return (1,1) # domain of lhs is gt domain of rhs
if ub1 <= lb2: return (0,0) # domain of lhs is leq domain of rhs
return (0,1)


class Operator(Expression):
"""
Expand Down
2 changes: 1 addition & 1 deletion cpmpy/expressions/globalconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def decompose(self):
decomp = [sum(self.args[:2]) == 1]
if len(self.args) > 2:
decomp = Xor([decomp,self.args[2:]]).decompose()[0]
return decomp, []
return cp.transformations.normalize.simplify_boolean(decomp), []

def value(self):
return sum(argvals(self.args)) % 2 == 1
Expand Down
11 changes: 10 additions & 1 deletion cpmpy/transformations/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ..expressions.core import BoolVal, Expression, Comparison, Operator
from ..expressions.globalfunctions import GlobalFunction
from ..expressions.utils import eval_comparison, is_false_cst, is_true_cst, is_boolexpr, is_num, is_bool
from ..expressions.utils import eval_comparison, is_false_cst, is_true_cst, is_boolexpr, is_num, is_bool, get_bounds
from ..expressions.variables import NDVarArray, _BoolVarImpl
from ..exceptions import NotSupportedError
from ..expressions.globalconstraints import GlobalConstraint
Expand Down Expand Up @@ -169,6 +169,15 @@ def simplify_boolean(lst_of_expr, num_context=False):
elif isinstance(expr, Comparison):
lhs, rhs = simplify_boolean(expr.args, num_context=True)
name = expr.name

lb, ub = get_bounds(eval_comparison(name, lhs, rhs))
if lb == 0 == ub:
newlist.append(0 if num_context else BoolVal(False))
continue
if lb == 1 == ub:
newlist.append(1 if num_context else BoolVal(True))
continue

if is_num(lhs) and is_boolexpr(rhs): # flip arguments of comparison to reduct nb of cases
if name == "<": name = ">"
elif name == ">": name = "<"
Expand Down
1 change: 1 addition & 0 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def global_constraints(solver):
if name == "Xor":
yield Xor(BOOL_ARGS)
yield Xor(BOOL_ARGS + [True,False])
yield Xor([True, BOOL_ARGS[0]])
continue
elif name == "Inverse":
expr = cls(NUM_ARGS, [1,0,2])
Expand Down
26 changes: 25 additions & 1 deletion tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from cpmpy.expressions import *
from cpmpy.expressions.variables import NDVarArray
from cpmpy.expressions.core import Comparison, Operator, Expression
from cpmpy.expressions.utils import eval_comparison, get_bounds, argval
from cpmpy.expressions.utils import eval_comparison, get_bounds, argval, all_pairs


class TestComparison(unittest.TestCase):
def test_comps(self):
Expand Down Expand Up @@ -450,6 +451,29 @@ def test_bounds_unary(self):
self.assertGreaterEqual(val,lb)
self.assertLessEqual(val,ub)

def test_bounds_comparison(self):

x_00 = intvar(0,0, name="x00")
x_01 = intvar(0,1, name="x01")
x_12= intvar(1,2, name="x12")
x_23 = intvar(2,3, name="x23")

for x,y in all_pairs([0, x_00, x_01, x_12, x_23]):
for comp in ['==','!=','<=','<','>=','>']:
x_bounds = get_bounds(x)
y_bounds = get_bounds(y)

total_vals = len(range(x_bounds[0],x_bounds[1]+1)) * len(range(y_bounds[0],y_bounds[1]+1))

for expr in [Comparison(comp, x,y), Comparison(comp, y,x)]:
lb, ub = expr.get_bounds()

if lb == 0 == ub:
self.assertEqual(cp.Model(expr).solveAll(), 0)
elif lb == 1 == ub:
self.assertEqual(cp.Model(expr).solveAll(), total_vals)
else:
self.assertNotEqual(cp.Model(expr).solveAll(), total_vals)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use assertLess()?
What with assertMore(..., 0)?


def test_incomplete_func(self):
# element constraint
Expand Down
17 changes: 17 additions & 0 deletions tests/test_globalconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,23 @@ def test_xor_with_constants(self):
self.assertFalse(cp.Model(cp.Xor([False, False])).solve())
self.assertFalse(cp.Model(cp.Xor([False, False, False])).solve())

def test_issue_620(self):
a = cp.boolvar()
b = cp.boolvar()
c = cp.boolvar()

model = cp.Model(cp.Xor([(cp.Xor([a, b, c])) <= True, ~((cp.Xor([a, b, c])) <= True)]))

self.assertTrue(model.solve(solver='ortools'))
if "minizinc" in cp.SolverLookup.supported():
self.assertTrue(model.solve(solver='minizinc'))
if "z3" in cp.SolverLookup.supported():
self.assertTrue(model.solve(solver='z3'))
if "choco" in cp.SolverLookup.supported():
self.assertTrue(model.solve(solver='choco'))
if "gurobi" in cp.SolverLookup.supported():
self.assertTrue(model.solve(solver='gurobi'))

def test_ite_with_constants(self):
x,y,z = cp.boolvar(shape=3)
expr = cp.IfThenElse(True, y, z)
Expand Down
Loading