Skip to content

REP-6492 Switch to $sampleRate-style partitioning #128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Aug 21, 2025
25 changes: 20 additions & 5 deletions internal/partitions/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package partitions
import (
"fmt"
"slices"
"time"

"github.com/10gen/migration-verifier/internal/util"
"github.com/10gen/migration-verifier/mbson"
Expand All @@ -13,13 +14,27 @@ import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
)

// ForPartitionAggregation returns a handle on coll's namespace that is
// configured for partitioning aggregations: read concern "available" and a
// secondary-preferred read preference with a 90-second maximum staleness.
// (Partitioning only needs approximate bounds, so slightly-stale reads are
// acceptable here.)
func ForPartitionAggregation(coll *mongo.Collection) *mongo.Collection {
	collOpts := options.Collection().
		SetReadConcern(readconcern.Available()).
		SetReadPreference(
			readpref.SecondaryPreferred(
				readpref.WithMaxStaleness(90 * time.Second),
			),
		)

	return coll.Database().Collection(coll.Name(), collOpts)
}

// PartitionKey represents the _id of a partition document stored in the destination.
type PartitionKey struct {
SourceUUID util.UUID `bson:"srcUUID"`
MongosyncID string `bson:"id"`
Lower any `bson:"lowerBound"`
SourceUUID util.UUID `bson:"srcUUID"`
Lower any `bson:"lowerBound"`
}

// Namespace stores the database and collection name of the namespace being copied.
Expand All @@ -46,8 +61,8 @@ type Partition struct {
// String returns a string representation of the partition.
func (p *Partition) String() string {
return fmt.Sprintf(
"{db: %s, coll: %s, collUUID: %s, mongosyncID: %s, lower: %s, upper: %s}",
p.Ns.DB, p.Ns.Coll, p.Key.SourceUUID, p.Key.MongosyncID, p.GetLowerBoundString(), p.GetUpperBoundString())
"{db: %s, coll: %s, collUUID: %s, lower: %s, upper: %s}",
p.Ns.DB, p.Ns.Coll, p.Key.SourceUUID, p.GetLowerBoundString(), p.GetUpperBoundString())
}

// GetLowerBoundString returns the string representation of this partition's lower bound.
Expand Down
5 changes: 2 additions & 3 deletions internal/partitions/partition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,8 @@ func getFilterFromFindOptions(opts bson.D) any {
func (suite *UnitTestSuite) makeTestPartition() (Partition, bson.D) {
partition := Partition{
Key: PartitionKey{
SourceUUID: util.NewUUID(),
Lower: primitive.ObjectID([12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}),
MongosyncID: "",
SourceUUID: util.NewUUID(),
Lower: primitive.ObjectID([12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}),
},
Ns: &Namespace{DB: "testDB", Coll: "testColl"},
Upper: primitive.ObjectID([12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}),
Expand Down
56 changes: 16 additions & 40 deletions internal/partitions/partitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package partitions
import (
"context"
"fmt"
"math/rand"

"github.com/10gen/migration-verifier/internal/logger"
"github.com/10gen/migration-verifier/internal/reportutils"
Expand All @@ -15,6 +14,7 @@ import (
"github.com/pkg/errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)

const (
Expand All @@ -29,13 +29,13 @@ const (
// possible that $sample does a collection scan if the number of documents increases very quickly, but
// that should be very rare.
//
defaultSampleRate = 0.04
sampleRate = 0.04

//
// The minimum number of documents $sample requires in order to use a pseudo-random cursor.
// See: https://docs.mongodb.com/manual/reference/operator/aggregation/sample/#behavior
//
defaultSampleMinNumDocs = 101
sampleMinNumDocs = 101

//
// The maximum number of documents to sample per partition. Previously this was set to 10.
Expand Down Expand Up @@ -70,13 +70,6 @@ const (
defaultPartitionSizeInBytes = 400 * 1024 * 1024 // = 400 MB
)

// Replicator contains the id of a mongosync replicator.
// It is used here to avoid changing the interface of partitioning (from the mongosync version)
// overmuch.
type Replicator struct {
ID string `bson:"id"`
}

// Partitions is a slice of partitions.
type Partitions struct {
logger *logger.Logger
Expand Down Expand Up @@ -123,7 +116,6 @@ func PartitionCollectionWithSize(
ctx context.Context,
uuidEntry *uuidutil.NamespaceAndUUID,
srcClient *mongo.Client,
replicatorList []Replicator,
subLogger *logger.Logger,
partitionSizeInBytes int64,
globalFilter bson.D,
Expand All @@ -138,13 +130,10 @@ func PartitionCollectionWithSize(
partitionSizeInBytes = defaultPartitionSizeInBytes
}

partitions, docCount, byteCount, err := PartitionCollectionWithParameters(
partitions, docCount, byteCount, err := partitionCollectionWithParameters(
ctx,
uuidEntry,
srcClient,
replicatorList,
defaultSampleRate,
defaultSampleMinNumDocs,
partitionSizeInBytes,
subLogger,
globalFilter,
Expand All @@ -157,13 +146,10 @@ func PartitionCollectionWithSize(
Str("filter", fmt.Sprintf("%+v", globalFilter)).
Msg("Timed out while partitioning with filter. Continuing by partitioning without the filter.")

return PartitionCollectionWithParameters(
return partitionCollectionWithParameters(
ctx,
uuidEntry,
srcClient,
replicatorList,
defaultSampleRate,
defaultSampleMinNumDocs,
partitionSizeInBytes,
subLogger,
nil,
Expand All @@ -173,17 +159,14 @@ func PartitionCollectionWithSize(
return partitions, docCount, byteCount, err
}

// PartitionCollectionWithParameters is the implementation for
// partitionCollectionWithParameters is the implementation for
// PartitionCollection. It is only directly used in integration tests.
// See PartitionCollectionWithParameters for a description of inputs
"See PartitionCollection for a description of inputs
// & outputs. (Alas, the parameter order differs slightly here …)
func PartitionCollectionWithParameters(
func partitionCollectionWithParameters(
ctx context.Context,
uuidEntry *uuidutil.NamespaceAndUUID,
srcClient *mongo.Client,
replicatorList []Replicator,
sampleRate float64,
sampleMinNumDocs int,
partitionSizeInBytes int64,
subLogger *logger.Logger,
globalFilter bson.D,
Expand Down Expand Up @@ -315,9 +298,6 @@ func PartitionCollectionWithParameters(
Msg("_id bounds should outnumber partitions by 1.")
}

// Choose a random index to start to avoid over-assigning partitions to a specific replicator.
// rand.Int() generates non-negative integers only.
replIndex := rand.Int() % len(replicatorList)
subLogger.Debug().
Int("numPartitions", len(allIDBounds)-1).
Str("namespace", uuidEntry.DBName+"."+uuidEntry.CollName).
Expand All @@ -329,9 +309,8 @@ func PartitionCollectionWithParameters(

for i := 0; i < len(allIDBounds)-1; i++ {
partitionKey := PartitionKey{
SourceUUID: uuidEntry.UUID,
MongosyncID: replicatorList[replIndex].ID,
Lower: allIDBounds[i],
SourceUUID: uuidEntry.UUID,
Lower: allIDBounds[i],
}
partition := &Partition{
Key: partitionKey,
Expand All @@ -340,8 +319,6 @@ func PartitionCollectionWithParameters(
IsCapped: isCapped,
}
partitions = append(partitions, partition)

replIndex = (replIndex + 1) % len(replicatorList)
}

return partitions, types.DocumentCount(collDocCount), types.ByteCount(collSizeInBytes), nil
Expand Down Expand Up @@ -641,13 +618,12 @@ func getMidIDBounds(
WithCallback(
func(ctx context.Context, ri *retry.FuncInfo) error {
ri.Log(logger.Logger, "aggregate", "source", srcDB.Name(), collName, "Retrieving mid _id partition bounds using $sample.")
cursor, cmdErr :=
srcDB.RunCommandCursor(ctx, bson.D{
{"aggregate", collName},
{"pipeline", pipeline},
{"allowDiskUse", true},
{"cursor", bson.D{}},
})
cursor, cmdErr := ForPartitionAggregation(srcDB.Collection(collName)).
Aggregate(
ctx,
pipeline,
options.Aggregate().SetAllowDiskUse(true),
)

if cmdErr != nil {
return errors.Wrapf(cmdErr, "failed to $sample and $bucketAuto documents for source namespace '%s.%s'", srcDB.Name(), collName)
Expand Down
4 changes: 2 additions & 2 deletions internal/reportutils/reportutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func BytesToUnit[T num16Plus](count T, unit DataUnit) string {
panic(fmt.Sprintf("Missing unit in unitSize: %s", unit))
}

return FmtReal(util.Divide(count, myUnitSize))
return FmtReal(util.DivideToF64(count, myUnitSize))
}

// FmtReal provides a standard formatting of real numbers, with a consistent
Expand Down Expand Up @@ -128,7 +128,7 @@ func roundFloat(val float64, precision uint) float64 {
}

func fmtQuotient[T, U realNum](dividend T, divisor U) string {
return FmtReal(util.Divide(dividend, divisor))
return FmtReal(util.DivideToF64(dividend, divisor))
}

// FmtPercent returns a stringified percentage without a trailing `%`,
Expand Down
4 changes: 2 additions & 2 deletions internal/util/math.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package util

import "github.com/10gen/migration-verifier/internal/types"

// Divide is syntactic sugar around float64(numerator) / float64(denominator).
func Divide[N types.RealNumber, D types.RealNumber](numerator N, denominator D) float64 {
// DivideToF64 converts both operands to float64 and returns their quotient;
// i.e., it is syntactic sugar around float64(numerator) / float64(denominator).
func DivideToF64[N types.RealNumber, D types.RealNumber](numerator N, denominator D) float64 {
	top := float64(numerator)
	bottom := float64(denominator)

	return top / bottom
}
65 changes: 39 additions & 26 deletions internal/verifier/migration_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -757,21 +757,17 @@ func (verifier *Verifier) partitionAndInspectNamespace(ctx context.Context, name
return nil, nil, 0, 0, err
}

// The partitioner doles out ranges to replicators; we don't use that functionality so we just pass
// one "replicator".
replicator1 := partitions.Replicator{ID: "verifier"}
replicators := []partitions.Replicator{replicator1}
partitionList, srcDocs, srcBytes, err := partitions.PartitionCollectionWithSize(
ctx, namespaceAndUUID, verifier.srcClient, replicators, verifier.logger, verifier.partitionSizeInBytes, verifier.globalFilter)
ctx, namespaceAndUUID, verifier.srcClient, verifier.logger, verifier.partitionSizeInBytes, verifier.globalFilter)
if err != nil {
return nil, nil, 0, 0, err
}
// TODO: Test the empty collection (which returns no partitions)
if len(partitionList) == 0 {
partitionList = []*partitions.Partition{{
Key: partitions.PartitionKey{
SourceUUID: namespaceAndUUID.UUID,
MongosyncID: "verifier"},
SourceUUID: namespaceAndUUID.UUID,
},
Ns: &partitions.Namespace{
DB: namespaceAndUUID.DBName,
Coll: namespaceAndUUID.CollName}}}
Expand Down Expand Up @@ -1228,33 +1224,50 @@ func (verifier *Verifier) verifyMetadataAndPartitionCollection(
// matches between source & destination. Now we can partition the collection.

if task.Generation == 0 {
partitions, shardKeys, docsCount, bytesCount, err := verifier.partitionAndInspectNamespace(ctx, srcNs)
if err != nil {
return errors.Wrapf(
err,
"failed to partition collection %#q",
srcNs,
)
var partitionsCount int
var docsCount types.DocumentCount
var bytesCount types.ByteCount

if verifier.srcHasSampleRate() {
var err error
partitionsCount, docsCount, bytesCount, err = verifier.createPartitionTasksWithSampleRate(ctx, task)
if err != nil {
return errors.Wrapf(err, "partitioning %#q via $sampleRate", srcNs)
}
} else {
verifier.logger.Warn().
Msg("Source MongoDB version lacks $sampleRate. Using legacy partitioning logic. This may cause imbalanced partitions, which will impede performance.")

var partitions []*partitions.Partition
var shardKeys []string

partitions, shardKeys, docsCount, bytesCount, err = verifier.partitionAndInspectNamespace(ctx, srcNs)
if err != nil {
return errors.Wrapf(err, "partitioning %#q via $sample", srcNs)
}

partitionsCount = len(partitions)

for _, partition := range partitions {
_, err := verifier.InsertPartitionVerificationTask(ctx, partition, shardKeys, dstNs)
if err != nil {
return errors.Wrapf(
err,
"failed to insert a partition task for namespace %#q",
srcNs,
)
}
}
}

verifier.logger.Debug().
Int("workerNum", workerNum).
Str("namespace", srcNs).
Int("partitionsCount", len(partitions)).
Int("partitionsCount", partitionsCount).
Msg("Divided collection into partitions.")

task.SourceDocumentCount = docsCount
task.SourceByteCount = bytesCount

for _, partition := range partitions {
_, err := verifier.InsertPartitionVerificationTask(ctx, partition, shardKeys, dstNs)
if err != nil {
return errors.Wrapf(
err,
"failed to insert a partition task for namespace %#q",
srcNs,
)
}
}
}

if task.Status == verificationTaskProcessing {
Expand Down
50 changes: 50 additions & 0 deletions internal/verifier/migration_verifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,56 @@ func TestIntegration(t *testing.T) {
suite.Run(t, testSuite)
}

// TestPartitionEmptyCollection verifies that partitioning an empty source
// collection via createPartitionTasksWithSampleRate yields exactly one
// partition spanning the full _id space (MinKey..MaxKey) and reports zero
// documents and zero bytes.
func (suite *IntegrationTestSuite) TestPartitionEmptyCollection() {
	ctx := suite.Context()
	require := require.New(suite.T())

	verifier := suite.BuildVerifier()

	// Create an empty collection on the source cluster.
	db := suite.srcMongoClient.Database(suite.DBNameForTest())
	collName := "stuff"
	require.NoError(db.CreateCollection(ctx, collName))

	// Build a generation-0 collection-verification task for the new namespace.
	task := &VerificationTask{
		PrimaryKey: primitive.NewObjectID(),
		Generation: 0,
		Status:     verificationTaskAdded,
		Type:       verificationTaskVerifyCollection,
		QueryFilter: QueryFilter{
			Namespace: db.Name() + "." + collName,
			To:        db.Name() + "." + collName,
		},
	}

	partitions, docs, bytes, err := verifier.createPartitionTasksWithSampleRate(ctx, task)
	require.NoError(err, "should partition collection")

	// Even with no documents, the partitioner emits one catch-all partition.
	assert.EqualValues(suite.T(), 1, partitions, "should be 1 partition")
	assert.Zero(suite.T(), docs, "should be 0 docs")
	assert.Zero(suite.T(), bytes, "should be 0 bytes")

	// Partitioning should have enqueued a document-verification task.
	taskOpt, err := verifier.FindNextVerifyTaskAndUpdate(ctx)
	require.NoError(err, "should look up task")

	foundTask, gotTask := taskOpt.Get()
	require.True(gotTask, "should find task")

	require.Equal(verificationTaskVerifyDocuments, foundTask.Type, "task type")

	// The single partition must cover everything: MinKey through MaxKey.
	assert.Equal(
		suite.T(),
		primitive.MinKey{},
		foundTask.QueryFilter.Partition.Key.Lower,
		"min bound",
	)

	assert.Equal(
		suite.T(),
		primitive.MaxKey{},
		foundTask.QueryFilter.Partition.Upper,
		"max bound",
	)
}

func (suite *IntegrationTestSuite) TestProcessVerifyTask_Failure() {
verifier := suite.BuildVerifier()
ctx := suite.Context()
Expand Down
Loading