Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/zed.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ const (
defaultMaxRetryAttemptDuration = 2 * time.Second
defaultRetryJitterFraction = 0.5
importBulkRoute = "/authzed.api.v1.PermissionsService/ImportBulkRelationships"
exportBulkRoute = "/authzed.api.v1.PermissionsService/ExportBulkRelationships"
)

// NewClient defines an (overridable) means of creating a new client.
Expand Down Expand Up @@ -231,7 +232,10 @@ func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOpti

streamInterceptors := []grpc.StreamClientInterceptor{
zgrpcutil.StreamLogDispatchTrailers,
selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(importBulkRoute))),
// retrying the bulk import in backup/restore logic is handled manually.
// retrying bulk export is also handled manually, because the default behavior is
// to start at the beginning of the stream, which produces duplicate values.
selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(importBulkRoute, exportBulkRoute))),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is half of the fix - we don't want to automatically retry ExportBulk requests, because then we're not properly handling restarting the stream.

}

if !cobrautil.MustGetBool(cmd, "skip-version-check") {
Expand Down
106 changes: 72 additions & 34 deletions internal/cmd/backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func filterSchemaDefs(schema, prefix string) (filteredSchema string, err error)
}
}

return
return filteredSchema, nil
}

func hasRelPrefix(rel *v1.Relationship, prefix string) bool {
Expand Down Expand Up @@ -303,7 +303,7 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
}
}(&err)

c, err := client.NewClient(cmd)
spiceClient, err := client.NewClient(cmd)
if err != nil {
return fmt.Errorf("unable to initialize client: %w", err)
}
Expand All @@ -316,7 +316,7 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
return fmt.Errorf("error creating backup file encoder: %w", err)
}
} else {
encoder, zedToken, err = encoderForNewBackup(cmd, c, backupFile)
encoder, zedToken, err = encoderForNewBackup(cmd, spiceClient, backupFile)
if err != nil {
return err
}
Expand All @@ -343,17 +343,13 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
}

ctx := cmd.Context()
relationshipStream, err := c.ExportBulkRelationships(ctx, req)
if err != nil {
return fmt.Errorf("error exporting relationships: %w", err)
}

relationshipReadStart := time.Now()
tick := time.Tick(5 * time.Second)
bar := console.CreateProgressBar("processing backup")
progressBar := console.CreateProgressBar("processing backup")
var relsFilteredOut, relsProcessed uint64
defer func() {
_ = bar.Finish()
_ = progressBar.Finish()

evt := log.Info().
Uint64("filtered", relsFilteredOut).
Expand All @@ -369,28 +365,8 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
}
}()

for {
if err := ctx.Err(); err != nil {
if isCanceled(err) {
return context.Canceled
}

return fmt.Errorf("aborted backup: %w", err)
}

relsResp, err := relationshipStream.Recv()
if err != nil {
if isCanceled(err) {
return context.Canceled
}

if !errors.Is(err, io.EOF) {
return fmt.Errorf("error receiving relationships: %w", err)
}
break
}

for _, rel := range relsResp.Relationships {
err = takeBackup(ctx, spiceClient, req, func(response *v1.ExportBulkRelationshipsResponse) error {
for _, rel := range response.Relationships {
if hasRelPrefix(rel, prefixFilter) {
if err := encoder.Append(rel); err != nil {
return fmt.Errorf("error storing relationship: %w", err)
Expand All @@ -400,7 +376,7 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
}

relsProcessed++
if err := bar.Add(1); err != nil {
if err := progressBar.Add(1); err != nil {
return fmt.Errorf("error incrementing progress bar: %w", err)
}

Expand All @@ -419,15 +395,77 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
}
}

if err := writeProgress(progressFile, relsResp); err != nil {
if err := writeProgress(progressFile, response); err != nil {
return err
}
}
return nil
})

backupCompleted = true
return nil
}

func takeBackup(ctx context.Context, spiceClient client.Client, req *v1.ExportBulkRelationshipsRequest, processResponse func(*v1.ExportBulkRelationshipsResponse) error) error {
relationshipStream, err := spiceClient.ExportBulkRelationships(ctx, req)
if err != nil {
return fmt.Errorf("error exporting relationships: %w", err)
}
var lastResponse *v1.ExportBulkRelationshipsResponse
for {
if err := ctx.Err(); err != nil {
if isCanceled(err) {
return context.Canceled
}

return fmt.Errorf("aborted backup: %w", err)
}

relsResp, err := relationshipStream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
break
}

if isCanceled(err) {
return context.Canceled
}

if isRetryableError(err) {
// If the error is retryable, we overwrite the existing stream with a new
// stream based on a new request that starts at the cursor location of the
// last received response.

// Clone the request to ensure that we are keeping all other fields the same
newReq := req.CloneVT()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you need to clone the request? does req.OptionalCursor = lastResponse.AfterResultCursor not work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't strictly need to, but it's not expensive and I don't like mutating function parameters.

cursorToken := "undefined"
if lastResponse != nil && lastResponse.AfterResultCursor != nil {
newReq.OptionalCursor = lastResponse.AfterResultCursor
cursorToken = lastResponse.AfterResultCursor.Token
}

relationshipStream, err = spiceClient.ExportBulkRelationships(ctx, newReq)
log.Info().Err(err).Str("cursor-token", cursorToken).Msg("encountered retryable error, resuming after last known cursor")
// Bounce to the top of the loop
continue
}

return fmt.Errorf("error receiving relationships: %w", err)
}

// Get a reference to the last response in case we need to retry
// starting at its cursor
lastResponse = relsResp

// Process the response using the provided function
err = processResponse(relsResp)
if err != nil {
return err
}
}

return nil
}

// encoderForNewBackup creates a new encoder for a new zed backup file. It returns the ZedToken at which the backup
// must be taken.
func encoderForNewBackup(cmd *cobra.Command, c client.Client, backupFile *os.File) (*backupformat.Encoder, *v1.ZedToken, error) {
Expand Down
168 changes: 168 additions & 0 deletions internal/cmd/backup_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package cmd

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
Expand All @@ -12,6 +14,7 @@ import (
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

Expand All @@ -22,6 +25,7 @@ import (
"github.com/authzed/zed/internal/client"
"github.com/authzed/zed/internal/storage"
zedtesting "github.com/authzed/zed/internal/testing"
"github.com/authzed/zed/pkg/backupformat"
)

func init() {
Expand Down Expand Up @@ -604,3 +608,167 @@ func TestAddSizeErrInfo(t *testing.T) {
})
}
}

// TestTakeBackupMockWorksAsExpected is a smoke test for the mock itself: a
// single scripted response followed by the mock's implicit io.EOF must flow
// through takeBackup without error and consume every scripted Recv call.
func TestTakeBackupMockWorksAsExpected(t *testing.T) {
	scriptedRels := []*v1.Relationship{
		{
			Resource: &v1.ObjectReference{ObjectType: "resource", ObjectId: "foo"},
			Relation: "view",
			Subject: &v1.SubjectReference{
				Object: &v1.ObjectReference{ObjectType: "user", ObjectId: "jim"},
			},
		},
	}

	// The only scripted Recv result: one batch holding the single relationship.
	sendBatch := func() (*v1.ExportBulkRelationshipsResponse, error) {
		return &v1.ExportBulkRelationshipsResponse{Relationships: scriptedRels}, nil
	}
	mock := &mockClientForBackup{
		t:         t,
		recvCalls: []func() (*v1.ExportBulkRelationshipsResponse, error){sendBatch},
	}

	checkBatch := func(batch *v1.ExportBulkRelationshipsResponse) error {
		require.Len(t, batch.Relationships, 1, "expecting 1 rel in the list")
		return nil
	}

	backupErr := takeBackup(t.Context(), mock, &v1.ExportBulkRelationshipsRequest{}, checkBatch)
	require.NoError(t, backupErr)

	mock.assertAllRecvCalls()
}

// TestTakeBackupRecoversFromRetryableErrors scripts a stream that yields one
// batch, then a codes.Unavailable error, then a second batch. It verifies that
// takeBackup re-issues the export with the cursor of the first batch and that
// both batches reach the response callback exactly once, in order.
func TestTakeBackupRecoversFromRetryableErrors(t *testing.T) {
	// viewRel builds a "view" relationship between user:jim and the given resource.
	viewRel := func(resourceID string) *v1.Relationship {
		return &v1.Relationship{
			Resource: &v1.ObjectReference{
				ObjectType: "resource",
				ObjectId:   resourceID,
			},
			Relation: "view",
			Subject: &v1.SubjectReference{
				Object: &v1.ObjectReference{
					ObjectType: "user",
					ObjectId:   "jim",
				},
			},
		}
	}

	firstRels := []*v1.Relationship{viewRel("foo")}
	secondRels := []*v1.Relationship{viewRel("bar")}
	cursor := &v1.Cursor{
		Token: "a token",
	}

	mockClient := &mockClientForBackup{
		t: t,
		recvCalls: []func() (*v1.ExportBulkRelationshipsResponse, error){
			func() (*v1.ExportBulkRelationshipsResponse, error) {
				return &v1.ExportBulkRelationshipsResponse{
					Relationships: firstRels,
					// Need to test that this cursor is supplied
					AfterResultCursor: cursor,
				}, nil
			},
			func() (*v1.ExportBulkRelationshipsResponse, error) {
				// Return a retryable error
				return nil, status.Error(codes.Unavailable, "i fell over")
			},
			func() (*v1.ExportBulkRelationshipsResponse, error) {
				return &v1.ExportBulkRelationshipsResponse{
					Relationships: secondRels,
					AfterResultCursor: &v1.Cursor{
						Token: "some other token",
					},
				}, nil
			},
		},
		exportCalls: []func(t *testing.T, req *v1.ExportBulkRelationshipsRequest){
			// Initial request
			func(_ *testing.T, _ *v1.ExportBulkRelationshipsRequest) {},
			// The retried request - asserting that it's called with the cursor
			func(t *testing.T, req *v1.ExportBulkRelationshipsRequest) {
				require.NotNil(t, req)
				require.NotNil(t, req.OptionalCursor, "cursor should be set on retry")
				// Expected value first, actual second, so failure output is labelled correctly.
				require.Equal(t, cursor.Token, req.OptionalCursor.Token, "cursor token does not match expected, got %s", req.OptionalCursor.Token)
			},
		},
	}

	var actualRels []*v1.Relationship

	err := takeBackup(t.Context(), mockClient, &v1.ExportBulkRelationshipsRequest{}, func(response *v1.ExportBulkRelationshipsResponse) error {
		actualRels = append(actualRels, response.Relationships...)
		return nil
	})
	require.NoError(t, err)

	require.Len(t, actualRels, 2, "expecting two rels in the realized list")
	require.Equal(t, "foo", actualRels[0].Resource.ObjectId)
	require.Equal(t, "bar", actualRels[1].Resource.ObjectId)

	mockClient.assertAllRecvCalls()
	// Without this check the cursor assertion above could be skipped silently
	// if the retried export call were never made.
	require.Equal(t, len(mockClient.exportCalls), mockClient.exportCallsIndex, "every scripted ExportBulkRelationships call should have been made")
}

// mockClientForBackup is a scripted test double for takeBackup. It is passed
// to takeBackup as the client, and its ExportBulkRelationships method returns
// the mock itself as the response stream, so one value plays both roles.
type mockClientForBackup struct {
	// Embedded types. Only ExportBulkRelationships and Recv are overridden by
	// methods on the mock; the embedded values themselves are never set in the
	// tests visible here.
	client.Client
	grpc.ServerStreamingClient[v1.ExportBulkRelationshipsResponse]
	// NOTE(review): no visible code uses the embedded Encoder; confirm it is needed.
	backupformat.Encoder

	// t is the test that failures are reported against.
	t *testing.T

	// recvCalls scripts the results of successive Recv calls, one function per
	// call; recvCallIndex is the position of the next one to run.
	recvCalls     []func() (*v1.ExportBulkRelationshipsResponse, error)
	recvCallIndex int

	// exportCalls provides a handle on the calls made to ExportBulkRelationships,
	// allowing for assertions to be made against those calls. exportCallsIndex
	// is the position of the next one to run.
	exportCalls      []func(t *testing.T, req *v1.ExportBulkRelationshipsRequest)
	exportCallsIndex int
}

// Recv returns the result of the next scripted recvCalls entry and advances
// the script position. Once every entry has been used it reports io.EOF,
// which ends the stream.
func (m *mockClientForBackup) Recv() (*v1.ExportBulkRelationshipsResponse, error) {
	position := m.recvCallIndex
	if position == len(m.recvCalls) {
		// Script exhausted: signal end of stream.
		return nil, io.EOF
	}

	scripted := m.recvCalls[position]
	m.recvCallIndex = position + 1
	return scripted()
}

// ExportBulkRelationships records a call to the export endpoint and returns
// the mock itself as the response stream.
//
// When exportCalls is nil, calls are not checked at all. Otherwise each call
// is handed, in order, to the matching exportCalls entry so the test can make
// assertions about the request. A call beyond the scripted number fails the
// test immediately with a message naming the unexpected call.
func (m *mockClientForBackup) ExportBulkRelationships(_ context.Context, req *v1.ExportBulkRelationshipsRequest, _ ...grpc.CallOption) (grpc.ServerStreamingClient[v1.ExportBulkRelationshipsResponse], error) {
	if m.exportCalls == nil {
		// If the caller doesn't supply exportCalls, pass through
		return m, nil
	}
	if m.exportCallsIndex >= len(m.exportCalls) {
		// If invoked too many times, fail the test and say why.
		m.t.Fatalf("unexpected call #%d to ExportBulkRelationships: only %d calls were scripted", m.exportCallsIndex+1, len(m.exportCalls))
		return m, nil
	}
	exportCall := m.exportCalls[m.exportCallsIndex]
	m.exportCallsIndex++
	exportCall(m.t, req)
	return m, nil
}
Comment on lines +745 to +769
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels a little awkward/verbose, but it felt like the best way to represent multiple calls to each of these functions.


// assertAllRecvCalls asserts that every scripted recvCalls entry was consumed,
// i.e. that the number of Recv invocations matches the number of scripted
// calls. The trailing Recv that returns io.EOF does not advance the counter.
func (m *mockClientForBackup) assertAllRecvCalls() {
	// Mark as a helper so a failure is attributed to the calling test.
	m.t.Helper()
	require.Equal(m.t, len(m.recvCalls), m.recvCallIndex, "the number of provided recvCalls should match the number of invocations")
}
Loading
Loading