Skip to content

Commit 6dbb4fa

Browse files
authored
topdown: Handling default functions in Partial Eval (#7499)
Making Partial Eval (PE) respect default functions. Before this fix, Rego functions with declared default values weren't respected by PE, and the default declaration was omitted from generated support modules. Fixes: #7220 Signed-off-by: Johan Fylling <[email protected]>
1 parent 5a62130 commit 6dbb4fa

File tree

2 files changed

+171
-10
lines changed

2 files changed

+171
-10
lines changed

Diff for: v1/topdown/eval.go

+39-10
Original file line numberDiff line numberDiff line change
@@ -2019,15 +2019,36 @@ func (e evalFunc) eval(iter unifyIterator) error {
20192019
return e.e.saveCall(argCount, e.terms, iter)
20202020
}
20212021

2022-
if e.e.partial() && (e.e.inliningControl.shallow || e.e.inliningControl.Disabled(e.ref, false)) {
2023-
// check if the function definitions, or any of the arguments
2024-
// contain something unknown
2025-
unknown := e.e.unknown(e.ref, e.e.bindings)
2026-
for i := 1; !unknown && i <= argCount; i++ {
2027-
unknown = e.e.unknown(e.terms[i], e.e.bindings)
2022+
if e.e.partial() {
2023+
var mustGenerateSupport bool
2024+
2025+
if defRule := e.ir.Default; defRule != nil {
2026+
// The presence of a default func might force us to generate support
2027+
if len(defRule.Head.Args) == len(e.terms)-1 {
2028+
// The function is called without collecting the result in an output term,
2029+
// therefore any successful evaluation of the function is of interest, including the default value ...
2030+
if ret := defRule.Head.Value; ret == nil || !ret.Equal(ast.InternedBooleanTerm(false)) {
2031+
// ... unless the default value is false,
2032+
mustGenerateSupport = true
2033+
}
2034+
} else {
2035+
// The function is called with an output term, therefore any successful evaluation of the function is of interest.
2036+
// NOTE: Because of how the compiler rewrites function calls, we can't know if the result value is compared
2037+
// to a constant value, so we can't be as clever as we are for rules.
2038+
mustGenerateSupport = true
2039+
}
20282040
}
2029-
if unknown {
2030-
return e.partialEvalSupport(argCount, iter)
2041+
2042+
if mustGenerateSupport || e.e.inliningControl.shallow || e.e.inliningControl.Disabled(e.ref, false) {
2043+
// check if the function definitions, or any of the arguments
2044+
// contain something unknown
2045+
unknown := e.e.unknown(e.ref, e.e.bindings)
2046+
for i := 1; !unknown && i <= argCount; i++ {
2047+
unknown = e.e.unknown(e.terms[i], e.e.bindings)
2048+
}
2049+
if unknown {
2050+
return e.partialEvalSupport(argCount, iter)
2051+
}
20312052
}
20322053
}
20332054

@@ -2226,6 +2247,13 @@ func (e evalFunc) partialEvalSupport(declArgsLen int, iter unifyIterator) error
22262247
return err
22272248
}
22282249
}
2250+
2251+
if e.ir.Default != nil {
2252+
err := e.partialEvalSupportRule(e.ir.Default, path)
2253+
if err != nil {
2254+
return err
2255+
}
2256+
}
22292257
}
22302258

22312259
if !e.e.saveSupport.Exists(path) { // we haven't saved anything, nothing to call
@@ -2274,8 +2302,9 @@ func (e evalFunc) partialEvalSupportRule(rule *ast.Rule, path ast.Ref) error {
22742302
}
22752303

22762304
e.e.saveSupport.Insert(path, &ast.Rule{
2277-
Head: head,
2278-
Body: plugged,
2305+
Head: head,
2306+
Body: plugged,
2307+
Default: rule.Default,
22792308
})
22802309
}
22812310
child.traceRedo(rule)

Diff for: v1/topdown/topdown_partial_test.go

+132
Original file line numberDiff line numberDiff line change
@@ -4068,6 +4068,138 @@ func TestTopDownPartialEval(t *testing.T) {
40684068
},
40694069
wantQueries: []string{""}, // unconditional true
40704070
},
4071+
4072+
{
4073+
note: "default function, result not collected (non-false default value)",
4074+
query: "data.test.p = true",
4075+
modules: []string{`package test
4076+
default f(x) := true # return true if x.size is undefined
4077+
f(x) if {
4078+
x.size < 100
4079+
}
4080+
p if {
4081+
f(input.x)
4082+
}
4083+
`},
4084+
wantQueries: []string{"data.partial.test.f(input.x)"},
4085+
wantSupport: []string{
4086+
`package partial.test
4087+
4088+
default f(__local0__3) = true
4089+
f(__local1__2) = true if { __local2__2 = __local1__2.size; lt(__local2__2, 100) }`,
4090+
},
4091+
},
4092+
{
4093+
note: "default function, result not collected (false default value)",
4094+
query: "data.test.p = true",
4095+
modules: []string{`package test
4096+
default f(x) := false
4097+
f(x) if {
4098+
x.size < 100
4099+
}
4100+
p if {
4101+
f(input.x)
4102+
}
4103+
`},
4104+
wantQueries: []string{"lt(input.x.size, 100)"},
4105+
},
4106+
{
4107+
note: "default function, result comparison (same as default)",
4108+
query: "data.test.p = true",
4109+
modules: []string{`package test
4110+
default f(x) := true # return true if x.size is undefined
4111+
f(x) if {
4112+
x.size < 100
4113+
}
4114+
p if {
4115+
f(input.x) == true
4116+
}
4117+
`},
4118+
wantQueries: []string{"data.partial.test.f(input.x, true)"},
4119+
wantSupport: []string{
4120+
`package partial.test
4121+
4122+
default f(__local0__3) = true
4123+
f(__local1__2) = true if { __local3__2 = __local1__2.size; lt(__local3__2, 100) }`,
4124+
},
4125+
},
4126+
{
4127+
note: "default function, result comparison (not same as default)",
4128+
query: "data.test.p = true",
4129+
modules: []string{`package test
4130+
default f(x) := true # return true if x.size is undefined
4131+
f(x) := y if {
4132+
y := x.size < 100
4133+
}
4134+
p if {
4135+
f(input.x) == false
4136+
}
4137+
`},
4138+
wantQueries: []string{"data.partial.test.f(input.x, false)"},
4139+
wantSupport: []string{
4140+
`package partial.test
4141+
4142+
default f(__local0__3) = true
4143+
f(__local1__2) = __local2__2 if { __local5__2 = __local1__2.size; lt(__local5__2, 100, __local3__2); __local2__2 = __local3__2 }`,
4144+
},
4145+
},
4146+
{
4147+
note: "default function, saved result",
4148+
query: "data.test.p = x",
4149+
modules: []string{`package test
4150+
default f(x) := true # return true if x.size is undefined
4151+
f(x) if {
4152+
x.size < 100
4153+
}
4154+
p := x if {
4155+
x := f(input.x)
4156+
}
4157+
`},
4158+
wantQueries: []string{"data.partial.test.f(input.x, x)"},
4159+
wantSupport: []string{
4160+
`package partial.test
4161+
4162+
default f(__local0__3) = true
4163+
f(__local1__2) = true if { __local4__2 = __local1__2.size; lt(__local4__2, 100) }`,
4164+
},
4165+
},
4166+
{
4167+
// This test case is redundant, but serves as a counter example to the test above.
4168+
// Inlining can happen as there is no default function to consider
4169+
note: "default function (no default)",
4170+
query: "data.test.p = true",
4171+
modules: []string{`package test
4172+
f(x) if {
4173+
x.size < 100
4174+
}
4175+
p if {
4176+
f(input)
4177+
}
4178+
`},
4179+
wantQueries: []string{"lt(input.size, 100)"},
4180+
},
4181+
{
4182+
note: "default function, shallow inlining",
4183+
query: "data.test.p = true",
4184+
modules: []string{`package test
4185+
default f(x) := true # return true if x.size is undefined
4186+
f(x) if {
4187+
x.size < 100
4188+
}
4189+
p if {
4190+
f(input)
4191+
}
4192+
`},
4193+
shallow: true,
4194+
wantQueries: []string{"data.partial.test.p = true"},
4195+
wantSupport: []string{
4196+
`package partial.test
4197+
4198+
p = true if { __local3__1 = input; data.partial.test.f(__local3__1) }
4199+
default f(__local0__3) = true
4200+
f(__local1__2) = true if { __local2__2 = __local1__2.size; lt(__local2__2, 100) }`,
4201+
},
4202+
},
40714203
}
40724204

40734205
ctx := context.Background()

0 commit comments

Comments
 (0)