From cea55c73fafbf46f35971eedccb66e35ecdb5a04 Mon Sep 17 00:00:00 2001 From: Tanner Stirrat Date: Tue, 2 Sep 2025 14:09:32 -0600 Subject: [PATCH 1/3] Fix backup retry behavior --- internal/client/client.go | 6 +++++- internal/cmd/backup.go | 20 ++++++++++++++++++++ internal/cmd/backup_test.go | 2 ++ 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/internal/client/client.go b/internal/client/client.go index 4908319..38e3fc0 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -43,6 +43,7 @@ const ( defaultMaxRetryAttemptDuration = 2 * time.Second defaultRetryJitterFraction = 0.5 importBulkRoute = "/authzed.api.v1.PermissionsService/ImportBulkRelationships" + exportBulkRoute = "/authzed.api.v1.PermissionsService/ExportBulkRelationships" ) // NewClient defines an (overridable) means of creating a new client. @@ -231,7 +232,10 @@ func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOpti streamInterceptors := []grpc.StreamClientInterceptor{ zgrpcutil.StreamLogDispatchTrailers, - selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(importBulkRoute))), + // retrying the bulk import in backup/restore logic is handled manually. + // retrying bulk export is also handled manually, because the default behavior is + // to start at the beginning of the stream, which produces duplicate values. + selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(importBulkRoute, exportBulkRoute))), } if !cobrautil.MustGetBool(cmd, "skip-version-check") { diff --git a/internal/cmd/backup.go b/internal/cmd/backup.go index 00f1641..de6a0a5 100644 --- a/internal/cmd/backup.go +++ b/internal/cmd/backup.go @@ -369,6 +369,7 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { } }() + var lastResponse *v1.ExportBulkRelationshipsResponse for { if err := ctx.Err(); err != nil { if isCanceled(err) { @@ -384,12 +385,31 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { return context.Canceled } + if isRetryableError(err) { + // TODO: do we need to clean up the existing stream in some way? + // TODO: best way to test this? + // If the error is retryable, we overwrite the existing stream with a new + // stream based on a new request that starts at the cursor location of the + // last received response. + relationshipStream, err = c.ExportBulkRelationships(ctx, &v1.ExportBulkRelationshipsRequest{ + OptionalLimit: pageLimit, + OptionalCursor: lastResponse.AfterResultCursor, + }) + log.Info().Err(err).Str("cursor token", lastResponse.AfterResultCursor.Token).Msg("encountered retryable error, resuming stream after token") + // Bounce to the top of the loop + continue + } + if !errors.Is(err, io.EOF) { return fmt.Errorf("error receiving relationships: %w", err) } break } + // Get a reference to the last response in case we need to retry + // starting at its cursor + lastResponse = relsResp + for _, rel := range relsResp.Relationships { if hasRelPrefix(rel, prefixFilter) { if err := encoder.Append(rel); err != nil { diff --git a/internal/cmd/backup_test.go b/internal/cmd/backup_test.go index 886d46c..0fda553 100644 --- a/internal/cmd/backup_test.go +++ b/internal/cmd/backup_test.go @@ -438,6 +438,8 @@ func TestBackupCreateCmdFunc(t *testing.T) { } validateBackupWithFunc(t, f, testSchema, resp.WrittenAt, expectedRels, validationFunc) }) + t.Run("retryable errors pick up where the stream left off", func(_ *testing.T) { + }) } type testConfigStore struct { From 314ec3a589bdf7b20b635c0a85a2c8651ab48bfb Mon Sep 17 00:00:00 2001 From: Tanner Stirrat Date: Mon, 8 Sep 2025 14:10:07 -0600 Subject: [PATCH 2/3] Add tests --- docs/zed.md | 1 + internal/cmd/backup.go | 105 +++++++++++--------- internal/cmd/backup_test.go | 169 +++++++++++++++++++++++++++++++- internal/cmd/restorer_test.go | 16 +-- internal/commands/permission.go | 2 +- internal/commands/util.go | 6 +- internal/decode/decoder.go | 2 +- 7 files changed, 240 insertions(+), 61 deletions(-) diff --git a/docs/zed.md b/docs/zed.md index 28e1647..a2f1cfc 100644 --- a/docs/zed.md +++ b/docs/zed.md @@ -1317,6 +1317,7 @@ zed validate [flags] ### Options ``` + --fail-on-warn treat warnings as errors during validation --force-color force color code output even in non-tty environments --schema-type string force validation according to specific schema syntax ("", "composable", "standard") ``` diff --git a/internal/cmd/backup.go b/internal/cmd/backup.go index de6a0a5..fe8f180 100644 --- a/internal/cmd/backup.go +++ b/internal/cmd/backup.go @@ -254,7 +254,7 @@ func filterSchemaDefs(schema, prefix string) (filteredSchema string, err error) } } - return + return filteredSchema, nil } func hasRelPrefix(rel *v1.Relationship, prefix string) bool { @@ -303,7 +303,7 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { } }(&err) - c, err := client.NewClient(cmd) + spiceClient, err := client.NewClient(cmd) if err != nil { return fmt.Errorf("unable to initialize client: %w", err) } @@ -316,7 +316,7 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { return fmt.Errorf("error creating backup file encoder: %w", err) } } else { - encoder, zedToken, err = encoderForNewBackup(cmd, c, backupFile) + encoder, zedToken, err = encoderForNewBackup(cmd, spiceClient, backupFile) if err != nil { return err } @@ -343,17 +343,13 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { } ctx := cmd.Context() - relationshipStream, err := c.ExportBulkRelationships(ctx, req) - if err != nil { - return fmt.Errorf("error exporting relationships: %w", err) - } relationshipReadStart := time.Now() tick := time.Tick(5 * time.Second) - bar := console.CreateProgressBar("processing backup") + progressBar := console.CreateProgressBar("processing backup") var relsFilteredOut, relsProcessed uint64 defer func() { - _ = bar.Finish() + _ = progressBar.Finish() evt := log.Info(). Uint64("filtered", relsFilteredOut). @@ -369,6 +365,51 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { } }() + err = takeBackup(ctx, spiceClient, req, func(response *v1.ExportBulkRelationshipsResponse) error { + for _, rel := range response.Relationships { + if hasRelPrefix(rel, prefixFilter) { + if err := encoder.Append(rel); err != nil { + return fmt.Errorf("error storing relationship: %w", err) + } + } else { + relsFilteredOut++ + } + + relsProcessed++ + if err := progressBar.Add(1); err != nil { + return fmt.Errorf("error incrementing progress bar: %w", err) + } + + // progress fallback in case there is no TTY + if !isatty.IsTerminal(os.Stderr.Fd()) { + select { + case <-tick: + log.Info(). + Uint64("filtered", relsFilteredOut). + Uint64("processed", relsProcessed). + Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))). + Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second)). + Msg("backup progress") + default: + } + } + } + + if err := writeProgress(progressFile, response); err != nil { + return err + } + return nil + }) + + backupCompleted = true + return nil +} + +func takeBackup(ctx context.Context, spiceClient client.Client, req *v1.ExportBulkRelationshipsRequest, processResponse func(*v1.ExportBulkRelationshipsResponse) error) error { + relationshipStream, err := spiceClient.ExportBulkRelationships(ctx, req) + if err != nil { + return fmt.Errorf("error exporting relationships: %w", err) + } var lastResponse *v1.ExportBulkRelationshipsResponse for { if err := ctx.Err(); err != nil { @@ -386,15 +427,16 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { } if isRetryableError(err) { - // TODO: do we need to clean up the existing stream in some way? // TODO: best way to test this? // If the error is retryable, we overwrite the existing stream with a new // stream based on a new request that starts at the cursor location of the // last received response. - relationshipStream, err = c.ExportBulkRelationships(ctx, &v1.ExportBulkRelationshipsRequest{ - OptionalLimit: pageLimit, - OptionalCursor: lastResponse.AfterResultCursor, - }) + + // Clone the request to ensure that we are keeping all other fields the same + newReq := req.CloneVT() + newReq.OptionalCursor = lastResponse.AfterResultCursor + + relationshipStream, err = spiceClient.ExportBulkRelationships(ctx, newReq) log.Info().Err(err).Str("cursor token", lastResponse.AfterResultCursor.Token).Msg("encountered retryable error, resuming stream after token") // Bounce to the top of the loop continue @@ -410,41 +452,12 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { // starting at its cursor lastResponse = relsResp - for _, rel := range relsResp.Relationships { - if hasRelPrefix(rel, prefixFilter) { - if err := encoder.Append(rel); err != nil { - return fmt.Errorf("error storing relationship: %w", err) - } - } else { - relsFilteredOut++ - } - - relsProcessed++ - if err := bar.Add(1); err != nil { - return fmt.Errorf("error incrementing progress bar: %w", err) - } - - // progress fallback in case there is no TTY - if !isatty.IsTerminal(os.Stderr.Fd()) { - select { - case <-tick: - log.Info(). - Uint64("filtered", relsFilteredOut). - Uint64("processed", relsProcessed). - Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))). - Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second)). - Msg("backup progress") - default: - } - } - } - - if err := writeProgress(progressFile, relsResp); err != nil { + // Process the response using the provided function + err = processResponse(relsResp) + if err != nil { return err } } - - backupCompleted = true return nil } diff --git a/internal/cmd/backup_test.go b/internal/cmd/backup_test.go index 0fda553..d47ca6f 100644 --- a/internal/cmd/backup_test.go +++ b/internal/cmd/backup_test.go @@ -1,9 +1,11 @@ package cmd import ( + "context" "encoding/json" "errors" "fmt" + "io" "os" "path/filepath" "strings" @@ -12,6 +14,7 @@ import ( "github.com/google/uuid" "github.com/rs/zerolog" "github.com/stretchr/testify/require" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -22,6 +25,7 @@ import ( "github.com/authzed/zed/internal/client" "github.com/authzed/zed/internal/storage" zedtesting "github.com/authzed/zed/internal/testing" + "github.com/authzed/zed/pkg/backupformat" ) func init() { @@ -438,8 +442,6 @@ func TestBackupCreateCmdFunc(t *testing.T) { } validateBackupWithFunc(t, f, testSchema, resp.WrittenAt, expectedRels, validationFunc) }) - t.Run("retryable errors pick up where the stream left off", func(_ *testing.T) { - }) } type testConfigStore struct { @@ -606,3 +608,166 @@ func TestAddSizeErrInfo(t *testing.T) { }) } } + +func TestTakeBackupMockWorksAsExpected(t *testing.T) { + rels := []*v1.Relationship{ + { + Resource: &v1.ObjectReference{ + ObjectType: "resource", + ObjectId: "foo", + }, + Relation: "view", + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: "user", + ObjectId: "jim", + }, + }, + }, + } + client := &mockClientForBackup{ + t: t, + recvCalls: []func() (*v1.ExportBulkRelationshipsResponse, error){ + func() (*v1.ExportBulkRelationshipsResponse, error) { + return &v1.ExportBulkRelationshipsResponse{ + Relationships: rels, + }, nil + }, + }, + } + + err := takeBackup(t.Context(), client, &v1.ExportBulkRelationshipsRequest{}, func(response *v1.ExportBulkRelationshipsResponse) error { + require.Len(t, response.Relationships, 1, "expecting 1 rel in the list") + return nil + }) + require.NoError(t, err) + + client.assertAllRecvCalls() +} + +func TestTakeBackupRecoversFromRetryableErrors(t *testing.T) { + firstRels := []*v1.Relationship{ + { + Resource: &v1.ObjectReference{ + ObjectType: "resource", + ObjectId: "foo", + }, + Relation: "view", + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: "user", + ObjectId: "jim", + }, + }, + }, + } + cursor := &v1.Cursor{ + Token: "an token", + } + secondRels := []*v1.Relationship{ + { + Resource: &v1.ObjectReference{ + ObjectType: "resource", + ObjectId: "bar", + }, + Relation: "view", + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: "user", + ObjectId: "jim", + }, + }, + }, + } + client := &mockClientForBackup{ + t: t, + recvCalls: []func() (*v1.ExportBulkRelationshipsResponse, error){ + func() (*v1.ExportBulkRelationshipsResponse, error) { + return &v1.ExportBulkRelationshipsResponse{ + Relationships: firstRels, + // Need to test that this cursor is supplied + AfterResultCursor: cursor, + }, nil + }, + func() (*v1.ExportBulkRelationshipsResponse, error) { + // Return a retryable error + return nil, status.Error(codes.Unavailable, "i fell over") + }, + func() (*v1.ExportBulkRelationshipsResponse, error) { + return &v1.ExportBulkRelationshipsResponse{ + Relationships: secondRels, + AfterResultCursor: &v1.Cursor{ + Token: "some other token", + }, + }, nil + }, + }, + exportCalls: []func(t *testing.T, req *v1.ExportBulkRelationshipsRequest){ + // Initial request + func(_ *testing.T, _ *v1.ExportBulkRelationshipsRequest) { + }, + // The retried request - asserting that it's called with the cursor + func(t *testing.T, req *v1.ExportBulkRelationshipsRequest) { + require.Equal(t, req.OptionalCursor.Token, cursor.Token, "cursor token does not match expected") + }, + }, + } + + actualRels := make([]*v1.Relationship, 0) + + err := takeBackup(t.Context(), client, &v1.ExportBulkRelationshipsRequest{}, func(response *v1.ExportBulkRelationshipsResponse) error { + actualRels = append(actualRels, response.Relationships...) + return nil + }) + require.NoError(t, err) + + require.Len(t, actualRels, 2, "expecting two rels in the realized list") + require.Equal(t, actualRels[0].Resource.ObjectId, "foo") + require.Equal(t, actualRels[1].Resource.ObjectId, "bar") + + client.assertAllRecvCalls() +} + +type mockClientForBackup struct { + client.Client + grpc.ServerStreamingClient[v1.ExportBulkRelationshipsResponse] + t *testing.T + backupformat.Encoder + recvCalls []func() (*v1.ExportBulkRelationshipsResponse, error) + recvCallIndex int + // exportCalls provides a handle on the calls made to ExportBulkRelationships, + // allowing for assertions to be made against those calls. + exportCalls []func(t *testing.T, req *v1.ExportBulkRelationshipsRequest) + exportCallsIndex int +} + +func (m *mockClientForBackup) Recv() (*v1.ExportBulkRelationshipsResponse, error) { + // If we've run through all our calls, return an EOF + if m.recvCallIndex == len(m.recvCalls) { + return nil, io.EOF + } + recvCall := m.recvCalls[m.recvCallIndex] + m.recvCallIndex++ + return recvCall() +} + +func (m *mockClientForBackup) ExportBulkRelationships(_ context.Context, req *v1.ExportBulkRelationshipsRequest, _ ...grpc.CallOption) (grpc.ServerStreamingClient[v1.ExportBulkRelationshipsResponse], error) { + if m.exportCalls == nil { + // If the caller doesn't supply exportCalls, pass through + return m, nil + } + if m.exportCallsIndex == len(m.exportCalls) { + // If invoked too many times, fail the test + m.t.FailNow() + return m, nil + } + exportCall := m.exportCalls[m.exportCallsIndex] + m.exportCallsIndex++ + exportCall(m.t, req) + return m, nil +} + +// assertAllRecvCalls asserts that the number of invocations is as expected +func (m *mockClientForBackup) assertAllRecvCalls() { + require.Equal(m.t, len(m.recvCalls), m.recvCallIndex, "the number of provided recvCalls should match the number of invocations") +} diff --git a/internal/cmd/restorer_test.go b/internal/cmd/restorer_test.go index f081634..77fadbc 100644 --- a/internal/cmd/restorer_test.go +++ b/internal/cmd/restorer_test.go @@ -87,7 +87,7 @@ func TestRestorer(t *testing.T) { remainderBatch = true } - c := &mockClient{ + c := &mockClientForRestore{ t: t, schema: testSchema, remainderBatch: remainderBatch, @@ -183,9 +183,9 @@ func TestRestorer(t *testing.T) { } } -type mockClient struct { +type mockClientForRestore struct { client.Client - v1.PermissionsService_ImportBulkRelationshipsClient + grpc.ClientStreamingClient[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse] t *testing.T schema string remainderBatch bool @@ -204,7 +204,7 @@ type mockClient struct { touchErrors []error } -func (m *mockClient) Send(req *v1.ImportBulkRelationshipsRequest) error { +func (m *mockClientForRestore) Send(req *v1.ImportBulkRelationshipsRequest) error { m.receivedBatches++ m.receivedRels += uint(len(req.Relationships)) m.lastReceivedBatch = req.Relationships @@ -227,7 +227,7 @@ func (m *mockClient) Send(req *v1.ImportBulkRelationshipsRequest) error { return nil } -func (m *mockClient) WriteRelationships(_ context.Context, in *v1.WriteRelationshipsRequest, _ ...grpc.CallOption) (*v1.WriteRelationshipsResponse, error) { +func (m *mockClientForRestore) WriteRelationships(_ context.Context, in *v1.WriteRelationshipsRequest, _ ...grpc.CallOption) (*v1.WriteRelationshipsResponse, error) { m.touchedBatches++ m.touchedRels += uint(len(in.Updates)) if m.touchedBatches <= uint(len(m.touchErrors)) { @@ -237,7 +237,7 @@ func (m *mockClient) WriteRelationships(_ context.Context, in *v1.WriteRelations return &v1.WriteRelationshipsResponse{}, nil } -func (m *mockClient) CloseAndRecv() (*v1.ImportBulkRelationshipsResponse, error) { +func (m *mockClientForRestore) CloseAndRecv() (*v1.ImportBulkRelationshipsResponse, error) { m.receivedCommits++ lastBatch := m.lastReceivedBatch defer func() { m.lastReceivedBatch = nil }() @@ -249,11 +249,11 @@ func (m *mockClient) CloseAndRecv() (*v1.ImportBulkRelationshipsResponse, error) return &v1.ImportBulkRelationshipsResponse{NumLoaded: uint64(len(lastBatch))}, nil } -func (m *mockClient) ImportBulkRelationships(_ context.Context, _ ...grpc.CallOption) (v1.PermissionsService_ImportBulkRelationshipsClient, error) { +func (m *mockClientForRestore) ImportBulkRelationships(_ context.Context, _ ...grpc.CallOption) (grpc.ClientStreamingClient[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse], error) { return m, nil } -func (m *mockClient) WriteSchema(_ context.Context, wsr *v1.WriteSchemaRequest, _ ...grpc.CallOption) (*v1.WriteSchemaResponse, error) { +func (m *mockClientForRestore) WriteSchema(_ context.Context, wsr *v1.WriteSchemaRequest, _ ...grpc.CallOption) (*v1.WriteSchemaResponse, error) { require.Equal(m.t, m.schema, wsr.Schema, "unexpected schema in write schema request") return &v1.WriteSchemaResponse{}, nil } diff --git a/internal/commands/permission.go b/internal/commands/permission.go index 58b022c..db7b6c1 100644 --- a/internal/commands/permission.go +++ b/internal/commands/permission.go @@ -65,7 +65,7 @@ func consistencyFromCmd(cmd *cobra.Command) (c *v1.Consistency, err error) { if c == nil { c = &v1.Consistency{Requirement: &v1.Consistency_MinimizeLatency{MinimizeLatency: true}} } - return + return c, err } func RegisterPermissionCmd(rootCmd *cobra.Command) *cobra.Command { diff --git a/internal/commands/util.go b/internal/commands/util.go index 6d5b4da..dde37bb 100644 --- a/internal/commands/util.go +++ b/internal/commands/util.go @@ -22,20 +22,20 @@ import ( func ParseSubject(s string) (namespace, id, relation string, err error) { err = stringz.SplitExact(s, ":", &namespace, &id) if err != nil { - return + return namespace, id, relation, err } err = stringz.SplitExact(id, "#", &id, &relation) if err != nil { relation = "" err = nil } - return + return namespace, id, relation, err } // ParseType parses a type reference of the form `namespace#relaion`. func ParseType(s string) (namespace, relation string) { namespace, relation, _ = strings.Cut(s, "#") - return + return namespace, relation } // GetCaveatContext returns the entered caveat caveat, if any. diff --git a/internal/decode/decoder.go b/internal/decode/decoder.go index 2d30c86..d493030 100644 --- a/internal/decode/decoder.go +++ b/internal/decode/decoder.go @@ -47,7 +47,7 @@ func DecoderForURL(u *url.URL) (d Func, err error) { default: err = fmt.Errorf("%s scheme not supported", s) } - return + return d, err } func fileDecoder(u *url.URL) Func { From a64c6238d1d89e690815e599be76c6ecae615b2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Rold=C3=A1n=20Betancort?= Date: Tue, 9 Sep 2025 14:28:28 +0100 Subject: [PATCH 3/3] fixes potential SIGSEV when failing on the first request since at that point we would have not received any response and we wouldn't have a token. Also made sure to verify the tests failed if I removed the usage of the last known token, which led me to improve the assertions, as the tests were SIGSEV'ing too. Also inverted the !errors.Is(err, io.EOF) check, which is more idiomatic --- internal/cmd/backup.go | 19 ++++++++++++------- internal/cmd/backup_test.go | 9 +++++---- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/internal/cmd/backup.go b/internal/cmd/backup.go index fe8f180..7b2ca3b 100644 --- a/internal/cmd/backup.go +++ b/internal/cmd/backup.go @@ -422,30 +422,34 @@ func takeBackup(ctx context.Context, spiceClient client.Client, req *v1.ExportBu relsResp, err := relationshipStream.Recv() if err != nil { + if errors.Is(err, io.EOF) { + break + } + if isCanceled(err) { return context.Canceled } if isRetryableError(err) { - // TODO: best way to test this? // If the error is retryable, we overwrite the existing stream with a new // stream based on a new request that starts at the cursor location of the // last received response. // Clone the request to ensure that we are keeping all other fields the same newReq := req.CloneVT() - newReq.OptionalCursor = lastResponse.AfterResultCursor + cursorToken := "undefined" + if lastResponse != nil && lastResponse.AfterResultCursor != nil { + newReq.OptionalCursor = lastResponse.AfterResultCursor + cursorToken = lastResponse.AfterResultCursor.Token + } relationshipStream, err = spiceClient.ExportBulkRelationships(ctx, newReq) - log.Info().Err(err).Str("cursor token", lastResponse.AfterResultCursor.Token).Msg("encountered retryable error, resuming stream after token") + log.Info().Err(err).Str("cursor-token", cursorToken).Msg("encountered retryable error, resuming after last known cursor") // Bounce to the top of the loop continue } - if !errors.Is(err, io.EOF) { - return fmt.Errorf("error receiving relationships: %w", err) - } - break + return fmt.Errorf("error receiving relationships: %w", err) } // Get a reference to the last response in case we need to retry @@ -458,6 +462,7 @@ func takeBackup(ctx context.Context, spiceClient client.Client, req *v1.ExportBu return err } } + return nil } diff --git a/internal/cmd/backup_test.go b/internal/cmd/backup_test.go index d47ca6f..6085894 100644 --- a/internal/cmd/backup_test.go +++ b/internal/cmd/backup_test.go @@ -662,7 +662,7 @@ func TestTakeBackupRecoversFromRetryableErrors(t *testing.T) { }, } cursor := &v1.Cursor{ - Token: "an token", + Token: "a token", } secondRels := []*v1.Relationship{ { @@ -704,11 +704,12 @@ func TestTakeBackupRecoversFromRetryableErrors(t *testing.T) { }, exportCalls: []func(t *testing.T, req *v1.ExportBulkRelationshipsRequest){ // Initial request - func(_ *testing.T, _ *v1.ExportBulkRelationshipsRequest) { - }, + func(_ *testing.T, _ *v1.ExportBulkRelationshipsRequest) {}, // The retried request - asserting that it's called with the cursor func(t *testing.T, req *v1.ExportBulkRelationshipsRequest) { - require.Equal(t, req.OptionalCursor.Token, cursor.Token, "cursor token does not match expected") + require.NotNil(t, req) + require.NotNil(t, req.OptionalCursor, "cursor should be set on retry") + require.Equal(t, req.OptionalCursor.Token, cursor.Token, "cursor token does not match expected, got %s", req.OptionalCursor.Token) }, }, }