Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions frontend/cs/r1cs/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,16 @@ func (builder *builder[E]) MulAcc(a, b, c frontend.Variable) frontend.Variable {
// results fits, _a is mutated without performing a new memalloc
builder.mbuf2 = builder.mbuf2[:0]
builder.add([]expr.LinearExpression[E]{_a, builder.mbuf1}, false, 0, &builder.mbuf2)
_a = _a[:0]
if len(builder.mbuf2) <= cap(_a) {

// if we can add the multiplication term to the accumulator LE (by having sufficient capacity)
// then we append directly into _a. However, _a can also be the hardcoded linear expressions corresponding
// to zero or one constant. Now, we we would append into those then we would modify the underlying slice
// thus modifying the constant themselves. This leads to undefined behaviour.
//
// So, in addition to only checking the capacity we also check that the underlying slices are different.
// to avoid using unsafe.Pointer, we check the address of the first elements.
if len(builder.mbuf2) <= cap(_a) && &(_a[0]) != &(builder.cstZero()[0]) && &(_a[0]) != &(builder.cstOne()[0]) {
_a = _a[:0]
// it fits, no mem alloc
_a = append(_a, builder.mbuf2...)
} else {
Expand Down
33 changes: 33 additions & 0 deletions frontend/cs/r1cs/r1cs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,36 @@ func TestSubSameNoConstraint(t *testing.T) {
t.Fatal("expected 0 constraints")
}
}

type overwriteZeroConstCircuit struct {
X frontend.Variable
Y frontend.Variable
}

func (c *overwriteZeroConstCircuit) Define(api frontend.API) error {
constVar := api.Mul(0, c.X) // this create a zero constant
constVar = api.MulAcc(constVar, c.X, c.Y) // due to bug the zero constant is overwritten
_ = constVar
constVar2 := api.Mul(0, c.Y) // this create another zero constant
api.AssertIsEqual(constVar2, 0)

return nil
}

func TestOverrideZeroConstant(t *testing.T) {
ccs, err := frontend.Compile(ecc.BN254.ScalarField(), NewBuilder, &overwriteZeroConstCircuit{})
if err != nil {
t.Fatal(err)
}
wit, err := frontend.NewWitness(&overwriteZeroConstCircuit{
X: 5,
Y: 10,
}, ecc.BN254.ScalarField())
if err != nil {
t.Fatal(err)
}
_, err = ccs.Solve(wit)
if err != nil {
t.Fatal(err)
}
}
Loading