Skip to content

Commit 6a59e6c

Browse files
authored
feat(scanner): add docs for Query/QueryRow, implement Collect (#11)
1 parent e1502a7 commit 6a59e6c

File tree

5 files changed

+86
-21
lines changed

5 files changed

+86
-21
lines changed

README.md

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,18 @@ type User struct {
5959
Name string `sql:"name"`
6060
}
6161

62+
// single column, single row:
63+
name, _ := queries.QueryRow[string](ctx, db, "SELECT name FROM users WHERE id = 1")
64+
65+
// single column, multiple rows:
66+
names, _ := queries.Collect(queries.Query[string](ctx, db, "SELECT name FROM users"))
67+
68+
// multiple columns, single row:
69+
user, _ := queries.QueryRow[User](ctx, db, "SELECT id, name FROM users WHERE id = 1")
70+
71+
// multiple columns, multiple rows:
6272
for user, _ := range queries.Query[User](ctx, db, "SELECT id, name FROM users") {
63-
// user.ID, user.Name
73+
// ...
6474
}
6575
```
6676

@@ -98,7 +108,6 @@ Integration tests cover the following databases and drivers:
98108

99109
## 🚧 TODOs
100110

101-
- Add missing documentation.
102111
- Add more tests for different databases and drivers. See https://go.dev/wiki/SQLDrivers.
103112
- Add examples for tested databases and drivers.
104113
- Add benchmarks.

builder.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ type Builder struct {
2525
// IMPORTANT: to avoid SQL injections, make sure to pass arguments from user input with placeholder verbs.
2626
// Always test your queries.
2727
//
28-
// Placeholder verbs to database placeholders:
28+
// Placeholder verbs map to the following database placeholders:
2929
// - MySQL, SQLite: %? -> ?
3030
// - PostgreSQL: %$ -> $N
3131
// - MSSQL: %@ -> @pN

query.go

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,40 @@ import (
1010
"time"
1111
)
1212

13-
type queryer interface {
13+
// Queryer is an interface implemented by [sql.DB] and [sql.Tx].
14+
type Queryer interface {
1415
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
1516
}
1617

17-
// TODO: document me.
18-
func Query[T any](ctx context.Context, q queryer, query string, args ...any) iter.Seq2[T, error] {
18+
// Query executes a query that returns rows, scans each row into a T, and returns an iterator over the Ts.
19+
// If an error occurs, the iterator yields it as the second value, and the caller should then stop the iteration.
20+
// [Queryer] can be either [sql.DB] or [sql.Tx], the rest of the arguments are passed directly to [Queryer.QueryContext].
21+
// Query fully manages the lifecycle of the [sql.Rows] returned by [Queryer.QueryContext], so the caller does not have to.
22+
//
23+
// The following Ts are supported:
24+
// - int (any kind)
25+
// - uint (any kind)
26+
// - float (any kind)
27+
// - bool
28+
// - string
29+
// - time.Time
30+
// - [sql.Scanner] (implemented by [sql.Null] types)
31+
// - any struct
32+
//
33+
// See the [sql.Rows.Scan] documentation for the scanning rules.
34+
// If the query has multiple columns, T must be a struct, other types can only be used for single-column queries.
35+
// The fields of a struct T must have the `sql:"COLUMN"` tag, where COLUMN is the name of the corresponding column in the query.
36+
// Unexported and untagged fields are ignored.
37+
//
38+
// Query panics if:
39+
// - The query has no columns.
40+
// - A non-struct T is specified with a multi-column query.
41+
// - The specified struct T has no field for one of the query columns.
42+
// - An unsupported T is specified.
43+
// - One of the fields in a struct T has an empty `sql` tag.
44+
//
45+
// If the caller prefers the result to be a slice rather than an iterator, Query can be combined with [Collect].
46+
func Query[T any](ctx context.Context, q Queryer, query string, args ...any) iter.Seq2[T, error] {
1947
return func(yield func(T, error) bool) {
2048
rows, err := q.QueryContext(ctx, query, args...)
2149
if err != nil {
@@ -47,8 +75,12 @@ func Query[T any](ctx context.Context, q queryer, query string, args ...any) ite
4775
}
4876
}
4977

50-
// TODO: document me.
51-
func QueryRow[T any](ctx context.Context, q queryer, query string, args ...any) (T, error) {
78+
// QueryRow is a [Query] variant for queries that are expected to return at most one row,
79+
// so instead of an iterator, it returns a single T.
80+
// Like [sql.DB.QueryRow], QueryRow returns [sql.ErrNoRows] if the query selects no rows,
81+
// otherwise it scans the first row and discards the rest.
82+
// See the [Query] documentation for details on supported Ts.
83+
func QueryRow[T any](ctx context.Context, q Queryer, query string, args ...any) (T, error) {
5284
rows, err := q.QueryContext(ctx, query, args...)
5385
if err != nil {
5486
return zero[T](), err
@@ -78,6 +110,19 @@ func QueryRow[T any](ctx context.Context, q queryer, query string, args ...any)
78110
return t, nil
79111
}
80112

113+
// Collect is a [slices.Collect] variant that collects values from an iter.Seq2[T, error].
114+
// If an error occurs during the collection, Collect stops the iteration and returns the error.
115+
func Collect[T any](seq iter.Seq2[T, error]) ([]T, error) {
116+
var ts []T
117+
for t, err := range seq {
118+
if err != nil {
119+
return nil, err
120+
}
121+
ts = append(ts, t)
122+
}
123+
return ts, nil
124+
}
125+
81126
func zero[T any]() (t T) { return t }
82127

83128
type scanner interface {

query_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,37 @@ package queries
33
import (
44
"database/sql"
55
"errors"
6+
"iter"
67
"reflect"
8+
"slices"
79
"testing"
810
"time"
911

1012
"go-simpler.org/queries/internal/assert"
1113
. "go-simpler.org/queries/internal/assert/EF"
1214
)
1315

16+
func TestCollect(t *testing.T) {
17+
anErr := errors.New("an error")
18+
19+
tests := map[string]struct {
20+
seq iter.Seq2[int, error]
21+
want []int
22+
wantErr error
23+
}{
24+
"no error": {slices.All([]error{nil, nil}), []int{0, 1}, nil},
25+
"an error": {slices.All([]error{nil, anErr}), nil, anErr},
26+
}
27+
28+
for name, tt := range tests {
29+
t.Run(name, func(t *testing.T) {
30+
got, err := Collect(tt.seq)
31+
assert.IsErr[F](t, err, tt.wantErr)
32+
assert.Equal[E](t, got, tt.want)
33+
})
34+
}
35+
}
36+
1437
func Test_scan(t *testing.T) {
1538
t.Run("no columns", func(t *testing.T) {
1639
fn := func() { _, _ = scan[int](nil, nil) }

tests/integration_test.go

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"database/sql"
66
"database/sql/driver"
7-
"iter"
87
"testing"
98
"time"
109

@@ -91,7 +90,7 @@ func TestIntegration(t *testing.T) {
9190
assert.NoErr[F](t, err)
9291
assert.Equal[E](t, name, TableUsers[0].Name)
9392

94-
names, err := collect(queries.Query[string](ctx, queryer, "SELECT name FROM users"))
93+
names, err := queries.Collect(queries.Query[string](ctx, queryer, "SELECT name FROM users"))
9594
assert.NoErr[F](t, err)
9695
assert.Equal[E](t, names, []string{TableUsers[0].Name, TableUsers[1].Name, TableUsers[2].Name})
9796

@@ -123,17 +122,6 @@ func namedToAny(values []driver.NamedValue) []any {
123122
return args
124123
}
125124

126-
func collect[T any](seq iter.Seq2[T, error]) ([]T, error) {
127-
var ts []T
128-
for t, err := range seq {
129-
if err != nil {
130-
return nil, err
131-
}
132-
ts = append(ts, t)
133-
}
134-
return ts, nil
135-
}
136-
137125
func migrate(ctx context.Context, db *sql.DB) error {
138126
type migration struct {
139127
query string

0 commit comments

Comments
 (0)