11package upserter
22
33import (
4- "bytes"
54 "context"
65 "database/sql/driver"
76 "fmt"
8- "reflect"
9- "time"
107
118 "github.com/upfluence/errors"
129 "github.com/upfluence/sql"
@@ -127,6 +124,12 @@ func newExecer(te txExecutor, stmt Statement) sqlbuilder.Execer {
127124 }
128125 }
129126
127+ var selectClauses = []sqlbuilder.Marker {oneMarker }
128+
129+ for _ , m := range stmt .SetValues {
130+ selectClauses = append (selectClauses , & assertMarker {Marker : m })
131+ }
132+
130133 var (
131134 clauses = make ([]sqlbuilder.PredicateClause , len (stmt .QueryValues ))
132135
@@ -136,7 +139,7 @@ func newExecer(te txExecutor, stmt Statement) sqlbuilder.Execer {
136139 sfs : make ([]string , len (stmt .SetValues )),
137140 ss : sqlbuilder.SelectStatement {
138141 Table : stmt .Table ,
139- SelectClauses : append ([]sqlbuilder. Marker { oneMarker }, stmt . SetValues ... ) ,
142+ SelectClauses : selectClauses ,
140143 },
141144 us : sqlbuilder.UpdateStatement {
142145 Table : stmt .Table ,
@@ -199,49 +202,19 @@ func newExecer(te txExecutor, stmt Statement) sqlbuilder.Execer {
199202 return & e
200203}
201204
202- func cloneValue (v any ) (any , error ) {
203- if dv , ok := v .(driver.Valuer ); ok {
204- vv , err := dv .Value ()
205-
206- if err != nil {
207- return nil , err
208- }
209-
210- if vv != nil {
211- v = vv
212- }
213- }
214-
215- return reflect .New (reflect .TypeOf (v )).Interface (), nil
205+ type assertMarker struct {
206+ sqlbuilder.Marker
216207}
217208
218- func equalValues (x , y any ) (bool , error ) {
219- if dy , ok := y .(driver.Valuer ); ok {
220- yy , err := dy .Value ()
221-
222- if err != nil {
223- return false , err
224- }
225-
226- if yy != nil {
227- y = yy
228- }
229- }
230-
231- switch yy := y .(type ) {
232- case time.Time :
233- if xx , ok := x .(time.Time ); ok {
234- return xx .Equal (yy ), nil
235- }
236- case []byte :
237- if xx , ok := x .([]byte ); ok {
238- return bytes .Equal (yy , xx ), nil
239- }
240- default :
241- return reflect .DeepEqual (x , y ), nil
242- }
209+ func (am * assertMarker ) WriteTo (qw sqlbuilder.QueryWriter , vs map [string ]interface {}) error {
210+ _ , err := fmt .Fprintf (
211+ qw ,
212+ "%s = %s" ,
213+ am .ToSQL (),
214+ qw .RedeemVariable (vs ["assert_" + am .Binding ()]),
215+ )
243216
244- return false , nil
217+ return err
245218}
246219
247220func (e * execer ) Exec (ctx context.Context , vs map [string ]interface {}) (sql.Result , error ) {
@@ -271,13 +244,11 @@ func (e *execer) Exec(ctx context.Context, vs map[string]interface{}) (sql.Resul
271244 return nil , sqlbuilder.ErrMissingKey {Key : f }
272245 }
273246
274- var err error
247+ var val sql. NullBool
275248
276- existing [ f ], err = cloneValue ( v )
249+ qvs [ "assert_" + f ] = v
277250
278- if err != nil {
279- return nil , err
280- }
251+ existing [f ] = & val
281252 }
282253
283254 if m := e .returningMarker ; m != nil {
@@ -296,14 +267,9 @@ func (e *execer) Exec(ctx context.Context, vs map[string]interface{}) (sql.Resul
296267 pristine := true
297268
298269 for _ , sf := range e .sfs {
299- ok , err := equalValues (reflect .ValueOf (existing [sf ]).Elem ().Interface (), vs [sf ])
300-
301- if err != nil {
302- return err
303- }
304-
305- if ! ok {
270+ if val := existing [sf ].(* sql.NullBool ); ! val .Bool {
306271 pristine = false
272+
307273 break
308274 }
309275 }
0 commit comments