Skip to content

Commit e1502a7

Browse files
authored
feat(scanner): support non-struct T for single-column queries (#10)
1 parent e399626 commit e1502a7

File tree

3 files changed

+161
-27
lines changed

3 files changed

+161
-27
lines changed

query.go

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"iter"
88
"reflect"
99
"sync"
10+
"time"
1011
)
1112

1213
type queryer interface {
@@ -84,29 +85,58 @@ type scanner interface {
8485
}
8586

8687
func scan[T any](s scanner, columns []string) (T, error) {
87-
var t T
88-
v := reflect.ValueOf(&t).Elem()
89-
if v.Kind() != reflect.Struct {
90-
panic("queries: T must be a struct")
88+
if len(columns) == 0 {
89+
panic("queries: no columns specified") // valid in PostgreSQL (for some reason).
9190
}
9291

93-
indexes := parseStruct(v.Type())
92+
var t T
93+
v := reflect.ValueOf(&t).Elem()
9494
args := make([]any, len(columns))
9595

96-
for i, column := range columns {
97-
idx, ok := indexes[column]
98-
if !ok {
99-
panic(fmt.Sprintf("queries: no field for column %q", column))
96+
switch {
97+
case scannable(v):
98+
if len(columns) > 1 {
99+
panic("queries: T must be a struct if len(columns) > 1")
100+
}
101+
args[0] = v.Addr().Interface()
102+
case v.Kind() == reflect.Struct:
103+
indexes := parseStruct(v.Type())
104+
for i, column := range columns {
105+
idx, ok := indexes[column]
106+
if !ok {
107+
panic(fmt.Sprintf("queries: no field for column %q", column))
108+
}
109+
args[i] = v.Field(idx).Addr().Interface()
100110
}
101-
args[i] = v.Field(idx).Addr().Interface()
111+
default:
112+
panic(fmt.Sprintf("queries: unsupported T %T", t))
102113
}
114+
103115
if err := s.Scan(args...); err != nil {
104116
return zero[T](), err
105117
}
106118

107119
return t, nil
108120
}
109121

122+
func scannable(v reflect.Value) bool {
123+
switch v.Kind() {
124+
case reflect.Bool,
125+
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
126+
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
127+
reflect.Float32, reflect.Float64,
128+
reflect.String:
129+
return true
130+
}
131+
if v.Type() == reflect.TypeFor[time.Time]() {
132+
return true
133+
}
134+
if v.Addr().Type().Implements(reflect.TypeFor[sql.Scanner]()) {
135+
return true
136+
}
137+
return false
138+
}
139+
110140
var cache sync.Map // map[reflect.Type]map[string]int
111141

112142
// parseStruct parses the given struct type and returns a map of column names to field indexes.

query_test.go

Lines changed: 93 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,133 @@
11
package queries
22

33
import (
4+
"database/sql"
45
"errors"
56
"reflect"
67
"testing"
8+
"time"
79

810
"go-simpler.org/queries/internal/assert"
911
. "go-simpler.org/queries/internal/assert/EF"
1012
)
1113

1214
func Test_scan(t *testing.T) {
13-
t.Run("non-struct T", func(t *testing.T) {
15+
t.Run("no columns", func(t *testing.T) {
1416
fn := func() { _, _ = scan[int](nil, nil) }
15-
assert.Panics[E](t, fn, "queries: T must be a struct")
17+
assert.Panics[E](t, fn, "queries: no columns specified")
18+
})
19+
20+
t.Run("unsupported T", func(t *testing.T) {
21+
columns := []string{"foo", "bar"}
22+
23+
fn := func() { _, _ = scan[complex64](nil, columns) }
24+
assert.Panics[E](t, fn, "queries: unsupported T complex64")
25+
})
26+
27+
t.Run("non-struct T with len(columns) > 1", func(t *testing.T) {
28+
columns := []string{"foo", "bar"}
29+
30+
fn := func() { _, _ = scan[int](nil, columns) }
31+
assert.Panics[E](t, fn, "queries: T must be a struct if len(columns) > 1")
1632
})
1733

1834
t.Run("empty tag", func(t *testing.T) {
35+
columns := []string{"foo", "bar"}
36+
1937
type row struct {
20-
Foo int `sql:""`
38+
Foo int `sql:"foo"`
39+
Bar string `sql:""`
2140
}
22-
fn := func() { _, _ = scan[row](nil, nil) }
23-
assert.Panics[E](t, fn, "queries: field Foo has an empty `sql` tag")
41+
fn := func() { _, _ = scan[row](nil, columns) }
42+
assert.Panics[E](t, fn, "queries: field Bar has an empty `sql` tag")
2443
})
2544

2645
t.Run("missing field", func(t *testing.T) {
46+
columns := []string{"foo", "bar"}
47+
2748
type row struct {
2849
Foo int `sql:"foo"`
2950
Bar string
3051
}
31-
fn := func() { _, _ = scan[row](nil, []string{"foo", "bar"}) }
52+
fn := func() { _, _ = scan[row](nil, columns) }
3253
assert.Panics[E](t, fn, `queries: no field for column "bar"`)
3354
})
3455

3556
t.Run("scan error", func(t *testing.T) {
36-
columns := []string{"foo"}
57+
columns := []string{"foo", "bar"}
3758
s := mockScanner{err: errors.New("an error")}
3859

3960
type row struct {
40-
Foo int `sql:"foo"`
61+
Foo int `sql:"foo"`
62+
Bar string `sql:"bar"`
4163
}
4264
_, err := scan[row](&s, columns)
4365
assert.IsErr[E](t, err, s.err)
4466
})
4567

46-
t.Run("ok", func(t *testing.T) {
68+
t.Run("struct T", func(t *testing.T) {
4769
columns := []string{"foo", "bar"}
48-
s := mockScanner{values: []any{1, "A"}}
70+
s := mockScanner{values: []any{1, "test"}}
4971

5072
type row struct {
5173
Foo int `sql:"foo"`
5274
Bar string `sql:"bar"`
5375
unexported bool
5476
}
55-
r, err := scan[row](&s, columns)
77+
v, err := scan[row](&s, columns)
78+
assert.NoErr[F](t, err)
79+
assert.Equal[E](t, v.Foo, 1)
80+
assert.Equal[E](t, v.Bar, "test")
81+
assert.Equal[E](t, v.unexported, false)
82+
})
83+
84+
t.Run("struct T with len(columns) == 1", func(t *testing.T) {
85+
columns := []string{"foo"}
86+
s := mockScanner{values: []any{1}}
87+
88+
type row struct {
89+
Foo int `sql:"foo"`
90+
}
91+
v, err := scan[row](&s, columns)
92+
assert.NoErr[F](t, err)
93+
assert.Equal[E](t, v.Foo, 1)
94+
})
95+
96+
t.Run("non-struct T with len(columns) == 1", func(t *testing.T) {
97+
columns := []string{"foo"}
98+
99+
tests := []struct {
100+
scan func(scanner) (any, error)
101+
value any
102+
}{
103+
{func(s scanner) (any, error) { return scan[bool](s, columns) }, true},
104+
{func(s scanner) (any, error) { return scan[int](s, columns) }, int(-1)},
105+
{func(s scanner) (any, error) { return scan[int8](s, columns) }, int8(-8)},
106+
{func(s scanner) (any, error) { return scan[int16](s, columns) }, int16(-16)},
107+
{func(s scanner) (any, error) { return scan[int32](s, columns) }, int32(-32)},
108+
{func(s scanner) (any, error) { return scan[int64](s, columns) }, int64(-64)},
109+
{func(s scanner) (any, error) { return scan[uint](s, columns) }, uint(1)},
110+
{func(s scanner) (any, error) { return scan[uint8](s, columns) }, uint8(8)},
111+
{func(s scanner) (any, error) { return scan[uint16](s, columns) }, uint16(16)},
112+
{func(s scanner) (any, error) { return scan[uint32](s, columns) }, uint32(32)},
113+
{func(s scanner) (any, error) { return scan[uint64](s, columns) }, uint64(64)},
114+
{func(s scanner) (any, error) { return scan[float32](s, columns) }, float32(0.32)},
115+
{func(s scanner) (any, error) { return scan[float64](s, columns) }, float64(0.64)},
116+
{func(s scanner) (any, error) { return scan[string](s, columns) }, "test"},
117+
{func(s scanner) (any, error) { return scan[time.Time](s, columns) }, time.Now()},
118+
}
119+
for _, tt := range tests {
120+
s := mockScanner{values: []any{tt.value}}
121+
v, err := tt.scan(&s)
122+
assert.NoErr[F](t, err)
123+
assert.Equal[E](t, v, tt.value)
124+
}
125+
126+
// sql.Scanner implementation:
127+
s := mockScanner{values: []any{"test"}}
128+
v, err := scan[sql.Null[string]](&s, columns)
56129
assert.NoErr[F](t, err)
57-
assert.Equal[E](t, r.Foo, 1)
58-
assert.Equal[E](t, r.Bar, "A")
59-
assert.Equal[E](t, r.unexported, false)
130+
assert.Equal[E](t, v, sql.Null[string]{V: "test", Valid: true})
60131
})
61132
}
62133

@@ -70,8 +141,14 @@ func (s *mockScanner) Scan(dst ...any) error {
70141
return s.err
71142
}
72143
for i := range dst {
73-
v := reflect.ValueOf(s.values[i])
74-
reflect.ValueOf(dst[i]).Elem().Set(v)
144+
if sc, ok := dst[i].(sql.Scanner); ok {
145+
if err := sc.Scan(s.values[i]); err != nil {
146+
return err
147+
}
148+
} else {
149+
v := reflect.ValueOf(s.values[i])
150+
reflect.ValueOf(dst[i]).Elem().Set(v)
151+
}
75152
}
76153
return nil
77154
}

tests/integration_test.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"database/sql/driver"
7+
"iter"
78
"testing"
89
"time"
910

@@ -38,13 +39,18 @@ func TestIntegration(t *testing.T) {
3839
ctx := t.Context()
3940

4041
for name, database := range DBs {
42+
var execCalls int
43+
var queryCalls int
44+
4145
interceptor := queries.Interceptor{
4246
Driver: database.driver,
4347
ExecContext: func(ctx context.Context, query string, args []driver.NamedValue, execer driver.ExecerContext) (driver.Result, error) {
48+
execCalls++
4449
t.Logf("[%s] ExecContext: %s %v", name, query, namedToAny(args))
4550
return execer.ExecContext(ctx, query, args)
4651
},
4752
QueryContext: func(ctx context.Context, query string, args []driver.NamedValue, queryer driver.QueryerContext) (driver.Rows, error) {
53+
queryCalls++
4854
t.Logf("[%s] QueryContext: %s %v", name, query, namedToAny(args))
4955
return queryer.QueryContext(ctx, query, args)
5056
},
@@ -78,9 +84,17 @@ func TestIntegration(t *testing.T) {
7884
for _, queryer := range []interface {
7985
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
8086
}{db, tx} {
81-
_, err := queries.QueryRow[User](ctx, queryer, "SELECT id, name, created_at FROM users WHERE id = 0")
87+
_, err := queries.QueryRow[string](ctx, queryer, "SELECT name FROM users WHERE id = 0")
8288
assert.IsErr[E](t, err, sql.ErrNoRows)
8389

90+
name, err := queries.QueryRow[string](ctx, queryer, "SELECT name FROM users WHERE id = 1")
91+
assert.NoErr[F](t, err)
92+
assert.Equal[E](t, name, TableUsers[0].Name)
93+
94+
names, err := collect(queries.Query[string](ctx, queryer, "SELECT name FROM users"))
95+
assert.NoErr[F](t, err)
96+
assert.Equal[E](t, names, []string{TableUsers[0].Name, TableUsers[1].Name, TableUsers[2].Name})
97+
8498
user, err := queries.QueryRow[User](ctx, queryer, "SELECT id, name, created_at FROM users WHERE id = 1")
8599
assert.NoErr[F](t, err)
86100
assert.Equal[E](t, user.ID, TableUsers[0].ID)
@@ -96,6 +110,8 @@ func TestIntegration(t *testing.T) {
96110
}
97111

98112
assert.NoErr[F](t, tx.Commit())
113+
assert.Equal[E](t, execCalls, 2)
114+
assert.Equal[E](t, queryCalls, 5*2)
99115
}
100116
}
101117

@@ -107,6 +123,17 @@ func namedToAny(values []driver.NamedValue) []any {
107123
return args
108124
}
109125

126+
func collect[T any](seq iter.Seq2[T, error]) ([]T, error) {
127+
var ts []T
128+
for t, err := range seq {
129+
if err != nil {
130+
return nil, err
131+
}
132+
ts = append(ts, t)
133+
}
134+
return ts, nil
135+
}
136+
110137
func migrate(ctx context.Context, db *sql.DB) error {
111138
type migration struct {
112139
query string

0 commit comments

Comments
 (0)