Skip to content

Commit 60a3d13

Browse files
x/sqlbuilder/upserter: Directly assert in the SQL query if the value are the correct ones
1 parent b1c0f8d commit 60a3d13

File tree

1 file changed

+22
-56
lines changed

1 file changed

+22
-56
lines changed

x/sqlbuilder/upserter/execer.go

Lines changed: 22 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
package upserter
22

33
import (
4-
"bytes"
54
"context"
65
"database/sql/driver"
76
"fmt"
8-
"reflect"
9-
"time"
107

118
"github.com/upfluence/errors"
129
"github.com/upfluence/sql"
@@ -127,6 +124,12 @@ func newExecer(te txExecutor, stmt Statement) sqlbuilder.Execer {
127124
}
128125
}
129126

127+
var selectClauses = []sqlbuilder.Marker{oneMarker}
128+
129+
for _, m := range stmt.SetValues {
130+
selectClauses = append(selectClauses, &assertMarker{Marker: m})
131+
}
132+
130133
var (
131134
clauses = make([]sqlbuilder.PredicateClause, len(stmt.QueryValues))
132135

@@ -136,7 +139,7 @@ func newExecer(te txExecutor, stmt Statement) sqlbuilder.Execer {
136139
sfs: make([]string, len(stmt.SetValues)),
137140
ss: sqlbuilder.SelectStatement{
138141
Table: stmt.Table,
139-
SelectClauses: append([]sqlbuilder.Marker{oneMarker}, stmt.SetValues...),
142+
SelectClauses: selectClauses,
140143
},
141144
us: sqlbuilder.UpdateStatement{
142145
Table: stmt.Table,
@@ -199,49 +202,19 @@ func newExecer(te txExecutor, stmt Statement) sqlbuilder.Execer {
199202
return &e
200203
}
201204

202-
func cloneValue(v any) (any, error) {
203-
if dv, ok := v.(driver.Valuer); ok {
204-
vv, err := dv.Value()
205-
206-
if err != nil {
207-
return nil, err
208-
}
209-
210-
if vv != nil {
211-
v = vv
212-
}
213-
}
214-
215-
return reflect.New(reflect.TypeOf(v)).Interface(), nil
205+
type assertMarker struct {
206+
sqlbuilder.Marker
216207
}
217208

218-
func equalValues(x, y any) (bool, error) {
219-
if dy, ok := y.(driver.Valuer); ok {
220-
yy, err := dy.Value()
221-
222-
if err != nil {
223-
return false, err
224-
}
225-
226-
if yy != nil {
227-
y = yy
228-
}
229-
}
230-
231-
switch yy := y.(type) {
232-
case time.Time:
233-
if xx, ok := x.(time.Time); ok {
234-
return xx.Equal(yy), nil
235-
}
236-
case []byte:
237-
if xx, ok := x.([]byte); ok {
238-
return bytes.Equal(yy, xx), nil
239-
}
240-
default:
241-
return reflect.DeepEqual(x, y), nil
242-
}
209+
func (am *assertMarker) WriteTo(qw sqlbuilder.QueryWriter, vs map[string]interface{}) error {
210+
_, err := fmt.Fprintf(
211+
qw,
212+
"%s = %s",
213+
am.ToSQL(),
214+
qw.RedeemVariable(vs["assert_"+am.Binding()]),
215+
)
243216

244-
return false, nil
217+
return err
245218
}
246219

247220
func (e *execer) Exec(ctx context.Context, vs map[string]interface{}) (sql.Result, error) {
@@ -271,13 +244,11 @@ func (e *execer) Exec(ctx context.Context, vs map[string]interface{}) (sql.Resul
271244
return nil, sqlbuilder.ErrMissingKey{Key: f}
272245
}
273246

274-
var err error
247+
var val sql.NullBool
275248

276-
existing[f], err = cloneValue(v)
249+
qvs["assert_"+f] = v
277250

278-
if err != nil {
279-
return nil, err
280-
}
251+
existing[f] = &val
281252
}
282253

283254
if m := e.returningMarker; m != nil {
@@ -296,14 +267,9 @@ func (e *execer) Exec(ctx context.Context, vs map[string]interface{}) (sql.Resul
296267
pristine := true
297268

298269
for _, sf := range e.sfs {
299-
ok, err := equalValues(reflect.ValueOf(existing[sf]).Elem().Interface(), vs[sf])
300-
301-
if err != nil {
302-
return err
303-
}
304-
305-
if !ok {
270+
if val := existing[sf].(*sql.NullBool); !val.Bool {
306271
pristine = false
272+
307273
break
308274
}
309275
}

0 commit comments

Comments
 (0)