diff --git a/pkg/sql/colexec/group/exec.go b/pkg/sql/colexec/group/exec.go
index 700d5aa1ccadd..0b52f45df11ee 100644
--- a/pkg/sql/colexec/group/exec.go
+++ b/pkg/sql/colexec/group/exec.go
@@ -63,6 +63,14 @@ func (group *Group) Prepare(proc *process.Process) (err error) {
 	if err = group.prepareGroup(proc); err != nil {
 		return err
 	}
+
+	if group.SpillManager == nil {
+		group.SpillManager = NewMemorySpillManager()
+	}
+	if group.SpillThreshold <= 0 {
+		group.SpillThreshold = 256 * 1024 //TODO configurable
+	}
+
 	return group.PrepareProjection(proc)
 }
 
@@ -188,6 +196,12 @@ func (group *Group) callToGetFinalResult(proc *process.Process) (*batch.Batch, e
 
 	for {
 		if group.ctr.state == vm.Eval {
+			if len(group.ctr.spilledStates) > 0 {
+				if err := group.mergeSpilledResults(proc); err != nil {
+					return nil, err
+				}
+			}
+
 			if group.ctr.result1.IsEmpty() {
 				group.ctr.state = vm.End
 				return nil, nil
@@ -203,7 +217,7 @@ func (group *Group) callToGetFinalResult(proc *process.Process) (*batch.Batch, e
 			group.ctr.state = vm.Eval
 
 			if group.ctr.isDataSourceEmpty() && len(group.Exprs) == 0 {
-				if err = group.generateInitialResult1WithoutGroupBy(proc); err != nil {
+				if err := group.generateInitialResult1WithoutGroupBy(proc); err != nil {
 					return nil, err
 				}
 				group.ctr.result1.ToPopped[0].SetRowCount(1)
@@ -215,6 +229,16 @@ func (group *Group) callToGetFinalResult(proc *process.Process) (*batch.Batch, e
 		}
 
 		group.ctr.dataSourceIsEmpty = false
+
+		// Spill first if already over the threshold, then fall through: res is
+		// already fetched and must still be consumed (a `continue` would drop it).
+		group.updateMemoryUsage(proc)
+		if group.shouldSpill() {
+			if err := group.spillPartialResults(proc); err != nil {
+				return nil, err
+			}
+		}
+
 		if err = group.consumeBatchToGetFinalResult(proc, res); err != nil {
 			return nil, err
 		}
@@ -245,7 +269,6 @@ func (group *Group) consumeBatchToGetFinalResult(
 
 	switch group.ctr.mtyp {
 	case H0:
-		// without group by.
 		if group.ctr.result1.IsEmpty() {
 			if err := group.generateInitialResult1WithoutGroupBy(proc); err != nil {
 				return err
@@ -260,7 +283,6 @@ func (group *Group) consumeBatchToGetFinalResult(
 		}
 
 	default:
-		// with group by.
 		if group.ctr.result1.IsEmpty() {
 			err := group.ctr.hr.BuildHashTable(false, group.ctr.mtyp == HStr, group.ctr.keyNullable, group.PreAllocSize)
 			if err != nil {
@@ -315,6 +337,16 @@ func (group *Group) consumeBatchToGetFinalResult(
 		}
 	}
 
+	// Update memory usage after processing the batch
+	group.updateMemoryUsage(proc)
+
+	// Check if we need to spill after processing this batch
+	if group.shouldSpill() {
+		if err := group.spillPartialResults(proc); err != nil {
+			return err
+		}
+	}
+
 	return nil
 }
 
diff --git a/pkg/sql/colexec/group/exec_test.go b/pkg/sql/colexec/group/exec_test.go
index 5e53f82e8aef2..6f5012d400c4e 100644
--- a/pkg/sql/colexec/group/exec_test.go
+++ b/pkg/sql/colexec/group/exec_test.go
@@ -46,6 +46,10 @@ func (h *hackAggExecToTest) GetOptResult() aggexec.SplitResult {
 	return nil
 }
 
+func (h *hackAggExecToTest) Size() int64 {
+	return 0
+}
+
 func (h *hackAggExecToTest) GroupGrow(more int) error {
 	h.groupNumber += more
 	return nil
diff --git a/pkg/sql/colexec/group/execctx.go b/pkg/sql/colexec/group/execctx.go
index 22fc688d8392d..01c80338aeb4e 100644
--- a/pkg/sql/colexec/group/execctx.go
+++ b/pkg/sql/colexec/group/execctx.go
@@ -123,7 +123,7 @@ type GroupResultBuffer struct {
 }
 
 func (buf *GroupResultBuffer) IsEmpty() bool {
-	return cap(buf.ToPopped) == 0
+	return len(buf.ToPopped) == 0
 }
 
 func (buf *GroupResultBuffer) InitOnlyAgg(chunkSize int, aggList []aggexec.AggFuncExec) {
diff --git a/pkg/sql/colexec/group/group_spill.go b/pkg/sql/colexec/group/group_spill.go
new file mode 100644
index 0000000000000..f397e0f4cb311
--- /dev/null
+++ b/pkg/sql/colexec/group/group_spill.go
@@ -0,0 +1,473 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in
compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package group
+
+import (
+	"fmt"
+
+	"github.com/matrixorigin/matrixone/pkg/container/batch"
+	"github.com/matrixorigin/matrixone/pkg/container/types"
+	"github.com/matrixorigin/matrixone/pkg/container/vector"
+	"github.com/matrixorigin/matrixone/pkg/logutil"
+	"github.com/matrixorigin/matrixone/pkg/sql/colexec/aggexec"
+	"github.com/matrixorigin/matrixone/pkg/vm/process"
+	"go.uber.org/zap"
+)
+
+// shouldSpill reports whether the in-memory aggregation state should be
+// spilled now. It relies on the usage figure last stored by
+// updateMemoryUsage and requires a spill manager, a positive threshold and at
+// least one aggregator and one result batch to spill.
+func (group *Group) shouldSpill() bool {
+	if group.SpillThreshold <= 0 || group.SpillManager == nil {
+		return false
+	}
+	if group.ctr.currentMemUsage <= group.SpillThreshold {
+		return false
+	}
+	if len(group.ctr.result1.AggList) == 0 || len(group.ctr.result1.ToPopped) == 0 {
+		return false
+	}
+
+	// This runs around every input batch, so only the positive decision is logged.
+	logutil.Info("[SPILL] Group operator triggering spill",
+		zap.Int64("current_memory_usage", group.ctr.currentMemUsage),
+		zap.Int64("spill_threshold", group.SpillThreshold),
+		zap.Int("agg_count", len(group.ctr.result1.AggList)),
+		zap.Int("batch_count", len(group.ctr.result1.ToPopped)))
+	return true
+}
+
+// updateMemoryUsage recomputes ctr.currentMemUsage as the sum of the hash
+// table size, the result batch sizes and the aggregator sizes. proc is unused
+// and kept only so existing call sites do not change.
+func (group *Group) updateMemoryUsage(proc *process.Process) {
+	usage := int64(0)
+	if !group.ctr.hr.IsEmpty() && group.ctr.hr.Hash != nil {
+		usage += int64(group.ctr.hr.Hash.Size())
+	}
+	for _, bat := range group.ctr.result1.ToPopped {
+		if bat != nil {
+			usage += int64(bat.Size())
+		}
+	}
+	for _, agg := range group.ctr.result1.AggList {
+		if agg != nil {
+			usage += agg.Size()
+		}
+	}
+	group.ctr.currentMemUsage = usage
+}
+
+// freeSpillAggs frees every non-nil aggregator in aggs.
+func freeSpillAggs(aggs []aggexec.AggFuncExec) {
+	for _, agg := range aggs {
+		if agg != nil {
+			agg.Free()
+		}
+	}
+}
+
+// releaseInMemoryState frees the aggregators, result batches and hash table
+// held by the container and resets the tracked memory usage to zero.
+func (group *Group) releaseInMemoryState(proc *process.Process) {
+	freeSpillAggs(group.ctr.result1.AggList)
+	group.ctr.result1.AggList = nil
+
+	for _, bat := range group.ctr.result1.ToPopped {
+		if bat != nil {
+			bat.Clean(proc.Mp())
+		}
+	}
+	group.ctr.result1.ToPopped = group.ctr.result1.ToPopped[:0]
+
+	if group.ctr.hr.Hash != nil {
+		group.ctr.hr.Hash.Free()
+		group.ctr.hr.Hash = nil
+	}
+	group.ctr.currentMemUsage = 0
+}
+
+// spillPartialResults hands the current aggregation state (marshaled
+// aggregators plus a copy of the group-by columns) to the spill manager,
+// records the returned ID in ctr.spilledStates and releases the in-memory
+// state. It is a no-op when there are no aggregators or result batches.
+func (group *Group) spillPartialResults(proc *process.Process) error {
+	if len(group.ctr.result1.AggList) == 0 || len(group.ctr.result1.ToPopped) == 0 {
+		logutil.Info("[SPILL] Group operator spill called but no data to spill")
+		return nil
+	}
+
+	totalGroups := 0
+	for _, bat := range group.ctr.result1.ToPopped {
+		if bat != nil {
+			totalGroups += bat.RowCount()
+		}
+	}
+	if totalGroups == 0 {
+		// Nothing worth persisting; still release the (empty) batches so they
+		// are cleaned rather than merely dropped from the slice.
+		logutil.Info("[SPILL] Group operator spill found no groups to spill")
+		group.releaseInMemoryState(proc)
+		return nil
+	}
+
+	marshaledAggStates := make([][]byte, len(group.ctr.result1.AggList))
+	for i, agg := range group.ctr.result1.AggList {
+		if agg == nil {
+			continue
+		}
+		marshaledData, err := aggexec.MarshalAggFuncExec(agg)
+		if err != nil {
+			logutil.Error("[SPILL] Group operator failed to marshal aggregator",
+				zap.Int("agg_index", i), zap.Error(err))
+			return err
+		}
+		marshaledAggStates[i] = marshaledData
+	}
+
+	// One accumulator vector per group-by column, typed after the first batch.
+	var groupVecs []*vector.Vector
+	var groupVecTypes []types.Type
+	if first := group.ctr.result1.ToPopped[0]; first != nil {
+		groupVecs = make([]*vector.Vector, len(first.Vecs))
+		groupVecTypes = make([]types.Type, len(first.Vecs))
+		for i, vec := range first.Vecs {
+			if vec != nil {
+				groupVecTypes[i] = *vec.GetType()
+				groupVecs[i] = vector.NewOffHeapVecWithType(groupVecTypes[i])
+			}
+		}
+	}
+
+	spillData := &SpillableAggState{
+		GroupVectors:       groupVecs,
+		GroupVectorTypes:   groupVecTypes,
+		MarshaledAggStates: marshaledAggStates,
+		GroupCount:         totalGroups,
+	}
+	// The working copy of the group columns is released on every path,
+	// including success; previously it was only freed when Spill failed.
+	// NOTE(review): assumes SpillManager.Spill does not retain spillData
+	// (MemorySpillManager keeps only the serialized bytes) - confirm for
+	// other implementations.
+	defer spillData.Free(proc.Mp())
+
+	for _, bat := range group.ctr.result1.ToPopped {
+		if bat == nil || bat.RowCount() == 0 {
+			continue
+		}
+		for i, vec := range bat.Vecs {
+			if i >= len(groupVecs) || groupVecs[i] == nil || vec == nil {
+				continue
+			}
+			if err := groupVecs[i].UnionBatch(vec, 0, vec.Length(), nil, proc.Mp()); err != nil {
+				logutil.Error("[SPILL] Group operator failed to union batch during spill",
+					zap.Int("vec_index", i), zap.Error(err))
+				return err
+			}
+		}
+	}
+
+	spillID, err := group.SpillManager.Spill(spillData)
+	if err != nil {
+		logutil.Error("[SPILL] Group operator failed to spill data", zap.Error(err))
+		return err
+	}
+	logutil.Info("[SPILL] Group operator successfully spilled data",
+		zap.String("spill_id", string(spillID)),
+		zap.Int("total_groups", totalGroups),
+		zap.Int64("estimated_size", spillData.EstimateSize()))
+
+	group.ctr.spilledStates = append(group.ctr.spilledStates, spillID)
+	group.releaseInMemoryState(proc)
+	return nil
+}
+
+// mergeSpilledResults retrieves every spilled state in order, merges it into
+// the in-memory result and deletes it from the spill manager. On return,
+// ctr.spilledStates holds only the states that were not merged.
+func (group *Group) mergeSpilledResults(proc *process.Process) error {
+	pending := group.ctr.spilledStates
+	if len(pending) == 0 {
+		return nil
+	}
+
+	logutil.Info("[SPILL] Group operator starting merge of spilled results",
+		zap.Int("spilled_states_count", len(pending)))
+
+	merged := 0
+	// Forget merged states even when a later one fails, so they are never
+	// retrieved or merged a second time.
+	defer func() {
+		if merged == len(pending) {
+			group.ctr.spilledStates = nil
+		} else {
+			group.ctr.spilledStates = pending[merged:]
+		}
+	}()
+
+	for _, spillID := range pending {
+		spillData, err := group.SpillManager.Retrieve(spillID, proc.Mp())
+		if err != nil {
+			logutil.Error("[SPILL] Group operator failed to retrieve spilled data",
+				zap.String("spill_id", string(spillID)), zap.Error(err))
+			return err
+		}
+
+		spillState, ok := spillData.(*SpillableAggState)
+		if !ok {
+			// Report the mismatch as an error instead of panicking inside the operator.
+			err = fmt.Errorf("invalid spilled data type %T for spill %s", spillData, spillID)
+			spillData.Free(proc.Mp())
+			return err
+		}
+
+		err = group.restoreAndMergeSpilledAggregators(proc, spillState)
+		spillState.Free(proc.Mp())
+		if err != nil {
+			logutil.Error("[SPILL] Group operator failed to restore and merge spilled aggregators",
+				zap.String("spill_id", string(spillID)), zap.Error(err))
+			return err
+		}
+		merged++
+
+		if err = group.SpillManager.Delete(spillID); err != nil {
+			logutil.Error("[SPILL] Group operator failed to delete spilled data",
+				zap.String("spill_id", string(spillID)), zap.Error(err))
+			return err
+		}
+	}
+
+	logutil.Info("[SPILL] Group operator completed merge of all spilled results",
+		zap.Int("merged_states_count", merged))
+	return nil
+}
+
+// unmarshalSpilledAggs rebuilds one aggregator per non-empty marshaled state;
+// empty states leave a nil entry at the same index. Extra configuration from
+// group.Aggs[i] is re-applied when present. On error every aggregator built so
+// far is freed and nil is returned.
+func (group *Group) unmarshalSpilledAggs(proc *process.Process, spillState *SpillableAggState) ([]aggexec.AggFuncExec, error) {
+	aggs := make([]aggexec.AggFuncExec, len(spillState.MarshaledAggStates))
+	for i, marshaledState := range spillState.MarshaledAggStates {
+		if len(marshaledState) == 0 {
+			continue
+		}
+
+		agg, err := aggexec.UnmarshalAggFuncExec(aggexec.NewSimpleAggMemoryManager(proc.Mp()), marshaledState)
+		if err != nil {
+			logutil.Error("[SPILL] Group operator failed to unmarshal aggregator",
+				zap.Int("agg_index", i), zap.Error(err))
+			freeSpillAggs(aggs)
+			return nil, err
+		}
+		aggs[i] = agg
+
+		if i >= len(group.Aggs) {
+			continue
+		}
+		if config := group.Aggs[i].GetExtraConfig(); config != nil {
+			if err = agg.SetExtraInformation(config, 0); err != nil {
+				logutil.Error("[SPILL] Group operator failed to set extra information for aggregator",
+					zap.Int("agg_index", i), zap.Error(err))
+				freeSpillAggs(aggs)
+				return nil, err
+			}
+		}
+	}
+	return aggs, nil
+}
+
+// appendSpilledGroupBatches copies the spilled group-by columns into batches
+// of at most chunkSize rows and appends them to result1.ToPopped. stage
+// ("restore" or "merge") only selects the wording of the error log. On error
+// nothing is appended and the batches built so far are cleaned.
+func (group *Group) appendSpilledGroupBatches(proc *process.Process, spillState *SpillableAggState, chunkSize int, stage string) error {
+	if len(spillState.GroupVectors) == 0 || spillState.GroupCount <= 0 {
+		return nil
+	}
+	if chunkSize <= 0 {
+		// A non-positive step would never advance the loop below.
+		chunkSize = spillState.GroupCount
+	}
+
+	restored := make([]*batch.Batch, 0, (spillState.GroupCount+chunkSize-1)/chunkSize)
+	for offset := 0; offset < spillState.GroupCount; offset += chunkSize {
+		size := chunkSize
+		if offset+size > spillState.GroupCount {
+			size = spillState.GroupCount - offset
+		}
+
+		bat := getInitialBatchWithSameTypeVecs(spillState.GroupVectors)
+		for i, vec := range spillState.GroupVectors {
+			if vec == nil || i >= len(bat.Vecs) {
+				continue
+			}
+			if err := bat.Vecs[i].UnionBatch(vec, int64(offset), size, nil, proc.Mp()); err != nil {
+				logutil.Error("[SPILL] Group operator failed to union batch during "+stage,
+					zap.Int("vec_index", i), zap.Int("offset", offset), zap.Error(err))
+				bat.Clean(proc.Mp())
+				for _, b := range restored {
+					b.Clean(proc.Mp())
+				}
+				return err
+			}
+		}
+		bat.SetRowCount(size)
+		restored = append(restored, bat)
+	}
+
+	group.ctr.result1.ToPopped = append(group.ctr.result1.ToPopped, restored...)
+	return nil
+}
+
+// restoreAndMergeSpilledAggregators folds one spilled state into result1.
+// With no live aggregators the spilled ones are adopted as-is; otherwise the
+// live aggregators are grown by GroupCount and each spilled group is merged
+// into its own new slot. In both cases the spilled group-by rows are appended
+// as new result batches. The caller keeps ownership of spillState.
+//
+// NOTE(review): spilled groups are appended as brand-new groups and are not
+// re-probed against the hash table, so a key that also exists in memory or in
+// another spilled state looks like it is emitted more than once - confirm
+// that a later stage re-aggregates, or re-hash here.
+func (group *Group) restoreAndMergeSpilledAggregators(proc *process.Process, spillState *SpillableAggState) error {
+	if len(spillState.MarshaledAggStates) == 0 {
+		logutil.Info("[SPILL] Group operator restore found no marshaled aggregator states")
+		return nil
+	}
+
+	// Unmarshal before touching any live state so a failure leaves result1 intact.
+	spilledAggs, err := group.unmarshalSpilledAggs(proc, spillState)
+	if err != nil {
+		return err
+	}
+
+	if len(group.ctr.result1.AggList) == 0 {
+		chunkSize := aggexec.GetMinAggregatorsChunkSize(spillState.GroupVectors, spilledAggs)
+		aggexec.SyncAggregatorsToChunkSize(spilledAggs, chunkSize)
+		group.ctr.result1.ChunkSize = chunkSize
+		// From here on result1 owns the aggregators.
+		group.ctr.result1.AggList = spilledAggs
+		return group.appendSpilledGroupBatches(proc, spillState, chunkSize, "restore")
+	}
+
+	// The spilled aggregators are only a merge source in this branch.
+	defer freeSpillAggs(spilledAggs)
+
+	// Spilled groups are placed right after the groups already in memory.
+	currentGroupCount := 0
+	for _, bat := range group.ctr.result1.ToPopped {
+		if bat != nil {
+			currentGroupCount += bat.RowCount()
+		}
+	}
+
+	for _, currentAgg := range group.ctr.result1.AggList {
+		if currentAgg == nil {
+			continue
+		}
+		if err = currentAgg.GroupGrow(spillState.GroupCount); err != nil {
+			logutil.Error("[SPILL] Group operator failed to grow aggregator groups", zap.Error(err))
+			return err
+		}
+	}
+
+	for i, spilledAgg := range spilledAggs {
+		// Skip slots with no counterpart instead of indexing past AggList.
+		if spilledAgg == nil || i >= len(group.ctr.result1.AggList) {
+			continue
+		}
+		currentAgg := group.ctr.result1.AggList[i]
+		if currentAgg == nil {
+			continue
+		}
+		for spilledGroupIdx := 0; spilledGroupIdx < spillState.GroupCount; spilledGroupIdx++ {
+			currentGroupIdx := currentGroupCount + spilledGroupIdx
+			if err = currentAgg.Merge(spilledAgg, currentGroupIdx, spilledGroupIdx); err != nil {
+				logutil.Error("[SPILL] Group operator failed to merge aggregator groups",
+					zap.Int("agg_index", i),
+					zap.Int("current_group_idx", currentGroupIdx),
+					zap.Int("spilled_group_idx", spilledGroupIdx),
+					zap.Error(err))
+				return err
+			}
+		}
+	}
+
+	return group.appendSpilledGroupBatches(proc, spillState, group.ctr.result1.ChunkSize, "merge")
+}
diff --git a/pkg/sql/colexec/group/spill.go b/pkg/sql/colexec/group/spill.go
new file mode 100644
index 0000000000000..cda576ca787bc
--- /dev/null
+++ b/pkg/sql/colexec/group/spill.go
@@ -0,0 +1,33 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package group
+
+import "github.com/matrixorigin/matrixone/pkg/common/mpool"
+
+// SpillID identifies one payload handed to a SpillManager. It is returned by
+// Spill and passed back to Retrieve and Delete.
+type SpillID string
+
+// SpillableData is a payload that can be turned into bytes and rebuilt from
+// them.
+type SpillableData interface {
+	// Serialize encodes the payload into a byte slice.
+	Serialize() ([]byte, error)
+	// Deserialize rebuilds the payload from data; mp is the memory pool
+	// available for any allocations this requires.
+	Deserialize(data []byte, mp *mpool.MPool) error
+	// EstimateSize returns an estimate of the payload size in bytes.
+	EstimateSize() int64
+	// Free releases the payload's resources; mp should be the pool they were
+	// allocated from.
+	Free(mp *mpool.MPool)
+}
+
+// SpillManager stores spilled payloads and hands them back on request.
+type SpillManager interface {
+	// Spill stores data and returns the ID under which it can be retrieved.
+	Spill(data SpillableData) (SpillID, error)
+	// Retrieve returns the payload stored under id, using mp for allocations.
+	Retrieve(id SpillID, mp *mpool.MPool) (SpillableData, error)
+	// Delete removes the payload stored under id.
+	Delete(id SpillID) error
+	// Free releases everything held by the manager.
+	Free()
+}
diff --git a/pkg/sql/colexec/group/spill_memory.go b/pkg/sql/colexec/group/spill_memory.go
new file mode 100644
index 0000000000000..38e2c4a76d3a9
--- /dev/null
+++ b/pkg/sql/colexec/group/spill_memory.go
@@ -0,0 +1,134 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package group
+
+import (
+	"fmt"
+	"sync"
+	"sync/atomic"
+
+	"github.com/matrixorigin/matrixone/pkg/common/mpool"
+	"github.com/matrixorigin/matrixone/pkg/logutil"
+	"go.uber.org/zap"
+)
+
+// MemorySpillManager is a SpillManager that keeps the serialized payloads in
+// an in-process map. The map is guarded by mu; nextID and totalMem are
+// updated atomically.
+type MemorySpillManager struct {
+	data     map[SpillID][]byte
+	nextID   int64
+	totalMem int64
+	mu       sync.Mutex
+}
+
+// NewMemorySpillManager returns an empty, ready-to-use MemorySpillManager.
+func NewMemorySpillManager() *MemorySpillManager {
+	return &MemorySpillManager{
+		data: make(map[SpillID][]byte),
+	}
+}
+
+// Spill serializes data, stores the bytes under a fresh "spill_<n>" ID and
+// returns that ID. Only the serialized bytes are kept; data itself is not
+// retained, so the caller remains responsible for freeing it.
+func (m *MemorySpillManager) Spill(data SpillableData) (SpillID, error) {
+	serialized, err := data.Serialize()
+	if err != nil {
+		logutil.Error("[SPILL] MemorySpillManager failed to serialize data", zap.Error(err))
+		return "", err
+	}
+
+	id := SpillID(fmt.Sprintf("spill_%d", atomic.AddInt64(&m.nextID, 1)))
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	if m.data == nil {
+		// Free drops the map; recreate it so a reused manager does not panic
+		// on a nil-map write.
+		m.data = make(map[SpillID][]byte)
+	}
+	m.data[id] = serialized
+	newTotalMem := atomic.AddInt64(&m.totalMem, int64(len(serialized)))
+
+	logutil.Info("[SPILL] MemorySpillManager spilled data",
+		zap.String("spill_id", string(id)),
+		zap.Int("data_size", len(serialized)),
+		zap.Int64("total_memory", newTotalMem))
+
+	return id, nil
+}
+
+// Retrieve deserializes the payload stored under id into a new
+// *SpillableAggState, using mp for allocations. The stored bytes stay in the
+// manager until Delete or Free. An unknown id yields an error.
+func (m *MemorySpillManager) Retrieve(id SpillID, mp *mpool.MPool) (SpillableData, error) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	serialized, exists := m.data[id]
+	if !exists {
+		logutil.Error("[SPILL] MemorySpillManager failed to find spilled data",
+			zap.String("spill_id", string(id)))
+		return nil, fmt.Errorf("spill data not found: %s", id)
+	}
+
+	data := &SpillableAggState{}
+	if err := data.Deserialize(serialized, mp); err != nil {
+		logutil.Error("[SPILL] MemorySpillManager failed to deserialize data",
+			zap.String("spill_id", string(id)), zap.Error(err))
+		// Release whatever was decoded before the failure.
+		data.Free(mp)
+		return nil, err
+	}
+
+	logutil.Info("[SPILL] MemorySpillManager successfully retrieved and deserialized data",
+		zap.String("spill_id", string(id)),
+		zap.Int("data_size", len(serialized)),
+		zap.Int64("estimated_size", data.EstimateSize()))
+
+	return data, nil
+}
+
+// Delete removes the payload stored under id and subtracts its size from the
+// running total. Deleting an unknown id only logs a warning; the returned
+// error is always nil.
+func (m *MemorySpillManager) Delete(id SpillID) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	serialized, exists := m.data[id]
+	if !exists {
+		logutil.Warn("[SPILL] MemorySpillManager attempted to delete non-existent spilled data",
+			zap.String("spill_id", string(id)))
+		return nil
+	}
+
+	newTotalMem := atomic.AddInt64(&m.totalMem, -int64(len(serialized)))
+	delete(m.data, id)
+
+	logutil.Info("[SPILL] MemorySpillManager deleted spilled data",
+		zap.String("spill_id", string(id)),
+		zap.Int("data_size", len(serialized)),
+		zap.Int64("total_memory", newTotalMem))
+	return nil
+}
+
+// Free drops every stored payload and resets the memory counter to zero.
+func (m *MemorySpillManager) Free() {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	logutil.Info("[SPILL] MemorySpillManager freeing all spilled data",
+		zap.Int("spilled_count", len(m.data)),
+		zap.Int64("total_size", atomic.LoadInt64(&m.totalMem)))
+
+	// Drop everything directly. Calling m.Delete here would try to lock m.mu
+	// again; sync.Mutex is not reentrant, so that deadlocks as soon as one
+	// entry is still stored.
+	m.data = nil
+	atomic.StoreInt64(&m.totalMem, 0)
+
+	logutil.Info("[SPILL] MemorySpillManager completed cleanup")
+}
+
+// TotalMem returns the total size in bytes of the payloads currently stored.
+func (m *MemorySpillManager) TotalMem() int64 {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	return atomic.LoadInt64(&m.totalMem)
+}
diff --git a/pkg/sql/colexec/group/spill_memory_test.go b/pkg/sql/colexec/group/spill_memory_test.go
new file mode 100644
index 0000000000000..eb488b60bab49
--- /dev/null
+++ b/pkg/sql/colexec/group/spill_memory_test.go
@@ -0,0 +1,87 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package group
+
+import (
+	"testing"
+
+	"github.com/matrixorigin/matrixone/pkg/testutil"
+	"github.com/stretchr/testify/require"
+)
+
+// TestMemorySpillManager covers the spill / retrieve / delete round trip of
+// MemorySpillManager, its memory accounting, and checks that the process
+// memory pool is back to its initial size at the end.
+func TestMemorySpillManager(t *testing.T) {
+	proc := testutil.NewProcess(t)
+	before := proc.Mp().CurrNB()
+
+	manager := NewMemorySpillManager()
+	defer manager.Free()
+
+	// Test spill and retrieve
+	data := &SpillableAggState{
+		GroupCount: 10,
+		MarshaledAggStates: [][]byte{
+			[]byte("test_agg_state_1"),
+			[]byte("test_agg_state_2"),
+		},
+	}
+
+	id, err := manager.Spill(data)
+	require.NoError(t, err)
+	require.NotEmpty(t, id)
+
+	retrieved, err := manager.Retrieve(id, proc.Mp())
+	require.NoError(t, err)
+	require.NotNil(t, retrieved)
+
+	spillState, ok := retrieved.(*SpillableAggState)
+	require.True(t, ok)
+	require.Equal(t, 10, spillState.GroupCount)
+	// Compare the payload bytes, not just how many states came back.
+	require.Equal(t, data.MarshaledAggStates, spillState.MarshaledAggStates)
+	spillState.Free(proc.Mp())
+
+	// Test delete
+	err = manager.Delete(id)
+	require.NoError(t, err)
+
+	_, err = manager.Retrieve(id, proc.Mp())
+	require.Error(t, err)
+
+	// Test memory accounting
+	require.Equal(t, int64(0), manager.TotalMem())
+
+	// Test with larger data
+	largeData := &SpillableAggState{
+		GroupCount:         1000,
+		MarshaledAggStates: make([][]byte, 100),
+	}
+	for i := range largeData.MarshaledAggStates {
+		largeData.MarshaledAggStates[i] = make([]byte, 1024) // 1KB per state
+	}
+
+	id, err = manager.Spill(largeData)
+	require.NoError(t, err)
+
+	// Verify memory usage increased
+	memAfterSpill := manager.TotalMem()
+	require.Greater(t, memAfterSpill, int64(0))
+
+	// Clean up
+	err = manager.Delete(id)
+	require.NoError(t, err)
+
+	require.Equal(t, int64(0), manager.TotalMem())
+
+	after := proc.Mp().CurrNB()
+	require.Equal(t, before, after, "Memory leak detected")
+}
diff --git a/pkg/sql/colexec/group/spill_test.go b/pkg/sql/colexec/group/spill_test.go
new file mode 100644
index 0000000000000..556178defa70f
--- 
/dev/null +++ b/pkg/sql/colexec/group/spill_test.go @@ -0,0 +1,15 @@ +// Copyright 2025 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package group diff --git a/pkg/sql/colexec/group/spillable_agg_state.go b/pkg/sql/colexec/group/spillable_agg_state.go new file mode 100644 index 0000000000000..6ed25d2ad92cc --- /dev/null +++ b/pkg/sql/colexec/group/spillable_agg_state.go @@ -0,0 +1,202 @@ +// Copyright 2025 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package group
+
+import (
+	"bytes"
+	"encoding/binary"
+	"io"
+
+	"github.com/matrixorigin/matrixone/pkg/common/mpool"
+	"github.com/matrixorigin/matrixone/pkg/container/types"
+	"github.com/matrixorigin/matrixone/pkg/container/vector"
+)
+
+// SpillableAggState is the SpillableData payload of the group operator: the
+// group-by columns, their types, one marshaled state per aggregator and the
+// number of groups.
+//
+// Wire format (all integers are little-endian int32): GroupCount, number of
+// vectors, number of types, then each type, each vector and - after a count -
+// each aggregator state as a length-prefixed chunk. A nil vector is written
+// as a zero-length chunk.
+type SpillableAggState struct {
+	GroupVectors       []*vector.Vector
+	GroupVectorTypes   []types.Type
+	MarshaledAggStates [][]byte
+	GroupCount         int
+}
+
+// writeSpillInt appends v as a little-endian int32.
+func writeSpillInt(buf *bytes.Buffer, v int) error {
+	return binary.Write(buf, binary.LittleEndian, int32(v))
+}
+
+// writeSpillChunk appends len(p) as an int32 followed by p itself.
+func writeSpillChunk(buf *bytes.Buffer, p []byte) error {
+	if err := writeSpillInt(buf, len(p)); err != nil {
+		return err
+	}
+	_, err := buf.Write(p)
+	return err
+}
+
+// readSpillInt reads one little-endian int32. Every integer in the format is
+// a count or a length, so a negative value is reported as corrupt input.
+func readSpillInt(r *bytes.Reader) (int, error) {
+	var v int32
+	if err := binary.Read(r, binary.LittleEndian, &v); err != nil {
+		return 0, err
+	}
+	if v < 0 {
+		return 0, io.ErrUnexpectedEOF
+	}
+	return int(v), nil
+}
+
+// readSpillCount reads an element count or byte length and rejects values
+// larger than the remaining input, so that corrupt data cannot trigger a huge
+// allocation.
+func readSpillCount(r *bytes.Reader) (int, error) {
+	n, err := readSpillInt(r)
+	if err != nil {
+		return 0, err
+	}
+	if n > r.Len() {
+		return 0, io.ErrUnexpectedEOF
+	}
+	return n, nil
+}
+
+// readSpillChunk reads one length-prefixed chunk. A zero length yields an
+// empty, non-nil slice.
+func readSpillChunk(r *bytes.Reader) ([]byte, error) {
+	size, err := readSpillCount(r)
+	if err != nil {
+		return nil, err
+	}
+	chunk := make([]byte, size)
+	// io.ReadFull fails on short input and, unlike a bare Read, does not
+	// return io.EOF for a zero-length chunk at the very end of the data.
+	if _, err = io.ReadFull(r, chunk); err != nil {
+		return nil, err
+	}
+	return chunk, nil
+}
+
+// Serialize encodes the state in the wire format described on
+// SpillableAggState.
+func (s *SpillableAggState) Serialize() ([]byte, error) {
+	buf := bytes.NewBuffer(nil)
+
+	for _, n := range []int{s.GroupCount, len(s.GroupVectors), len(s.GroupVectorTypes)} {
+		if err := writeSpillInt(buf, n); err != nil {
+			return nil, err
+		}
+	}
+
+	for _, typ := range s.GroupVectorTypes {
+		typBytes, err := typ.MarshalBinary()
+		if err != nil {
+			return nil, err
+		}
+		if err = writeSpillChunk(buf, typBytes); err != nil {
+			return nil, err
+		}
+	}
+
+	for _, vec := range s.GroupVectors {
+		var vecBytes []byte // stays empty for a nil vector
+		if vec != nil {
+			var err error
+			if vecBytes, err = vec.MarshalBinary(); err != nil {
+				return nil, err
+			}
+		}
+		if err := writeSpillChunk(buf, vecBytes); err != nil {
+			return nil, err
+		}
+	}
+
+	if err := writeSpillInt(buf, len(s.MarshaledAggStates)); err != nil {
+		return nil, err
+	}
+	for _, aggState := range s.MarshaledAggStates {
+		if err := writeSpillChunk(buf, aggState); err != nil {
+			return nil, err
+		}
+	}
+
+	return buf.Bytes(), nil
+}
+
+// Deserialize rebuilds the state from data produced by Serialize, allocating
+// vectors from mp. On error the receiver may be partially filled; call Free
+// to release the vectors decoded so far.
+func (s *SpillableAggState) Deserialize(data []byte, mp *mpool.MPool) error {
+	r := bytes.NewReader(data)
+
+	groupCount, err := readSpillInt(r)
+	if err != nil {
+		return err
+	}
+	s.GroupCount = groupCount
+
+	groupVecCount, err := readSpillCount(r)
+	if err != nil {
+		return err
+	}
+	groupVecTypeCount, err := readSpillCount(r)
+	if err != nil {
+		return err
+	}
+
+	s.GroupVectorTypes = make([]types.Type, groupVecTypeCount)
+	for i := range s.GroupVectorTypes {
+		typBytes, err := readSpillChunk(r)
+		if err != nil {
+			return err
+		}
+		if err = s.GroupVectorTypes[i].UnmarshalBinary(typBytes); err != nil {
+			return err
+		}
+	}
+
+	s.GroupVectors = make([]*vector.Vector, groupVecCount)
+	for i := range s.GroupVectors {
+		vecBytes, err := readSpillChunk(r)
+		if err != nil {
+			return err
+		}
+		if len(vecBytes) == 0 {
+			// Zero-length chunk marks a nil vector.
+			continue
+		}
+
+		vecType := types.T_any.ToType()
+		if i < len(s.GroupVectorTypes) {
+			vecType = s.GroupVectorTypes[i]
+		}
+		vec := vector.NewOffHeapVecWithType(vecType)
+		if err = vec.UnmarshalBinaryWithCopy(vecBytes, mp); err != nil {
+			vec.Free(mp)
+			return err
+		}
+		s.GroupVectors[i] = vec
+	}
+
+	aggStateCount, err := readSpillCount(r)
+	if err != nil {
+		return err
+	}
+	s.MarshaledAggStates = make([][]byte, aggStateCount)
+	for i := range s.MarshaledAggStates {
+		if s.MarshaledAggStates[i], err = readSpillChunk(r); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// EstimateSize returns the allocated size of the group vectors plus the total
+// length of the marshaled aggregator states.
+func (s *SpillableAggState) EstimateSize() int64 {
+	size := int64(0)
+	for _, vec := range s.GroupVectors {
+		if vec != nil {
+			size += int64(vec.Allocated())
+		}
+	}
+	for _, aggState := range s.MarshaledAggStates {
+		size += int64(len(aggState))
+	}
+	return size
+}
+
+// Free releases the group vectors to mp and clears all slices. Calling it
+// again afterwards is a no-op.
+func (s *SpillableAggState) Free(mp *mpool.MPool) {
+	for _, vec := range s.GroupVectors {
+		if vec != nil {
+			vec.Free(mp)
+		}
+	}
+	s.GroupVectors = nil
+	s.GroupVectorTypes = nil
+	s.MarshaledAggStates = nil
+}
diff --git a/pkg/sql/colexec/group/testspill/spill_test.go b/pkg/sql/colexec/group/testspill/spill_test.go
new file mode 100644
index 0000000000000..33b8db14cb896
--- /dev/null
+++ b/pkg/sql/colexec/group/testspill/spill_test.go
@@ -0,0 +1,323 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package testspill + +import ( + "database/sql" + "fmt" + "testing" + + "github.com/matrixorigin/matrixone/pkg/embed" + "github.com/stretchr/testify/require" +) + +func TestGroupSpillLargeGroups(t *testing.T) { + cluster, err := embed.NewCluster( + embed.WithCNCount(1), + embed.WithTesting(), + embed.WithPreStart(func(service embed.ServiceOperator) { + service.Adjust(func(config *embed.ServiceConfig) { + config.Log.Level = "debug" + }) + }), + ) + require.NoError(t, err) + err = cluster.Start() + require.NoError(t, err) + defer cluster.Close() + + cn0, err := cluster.GetCNService(0) + require.NoError(t, err) + dsn := fmt.Sprintf("dump:111@tcp(127.0.0.1:%d)/", + cn0.GetServiceConfig().CN.Frontend.Port, + ) + + db, err := sql.Open("mysql", dsn) + require.NoError(t, err) + defer db.Close() + + _, err = db.Exec(`CREATE DATABASE IF NOT EXISTS test_spill`) + require.NoError(t, err) + + _, err = db.Exec(`USE test_spill`) + require.NoError(t, err) + + _, err = db.Exec(` + CREATE TABLE IF NOT EXISTS large_group_test ( + id BIGINT PRIMARY KEY, + group_col1 INT, + group_col2 VARCHAR(100), + value_col BIGINT + ) + `) + require.NoError(t, err) + + _, err = db.Exec(` + INSERT INTO large_group_test (id, group_col1, group_col2, value_col) + SELECT + g.result as id, + FLOOR(RAND() * 100000) as group_col1, + CONCAT('group_', FLOOR(RAND() * 50000)) as group_col2, + FLOOR(RAND() * 1000) as value_col + FROM generate_series(1000000) g + `) + require.NoError(t, err) + + rows, err := db.Query(` + SELECT + group_col1, + group_col2, + COUNT(*) as cnt, + SUM(value_col) as sum_val + FROM large_group_test + GROUP BY group_col1, group_col2 + ORDER BY cnt DESC + LIMIT 10 + `) + require.NoError(t, err) + defer rows.Close() + + count := 0 + for rows.Next() { + count++ + } + require.NoError(t, rows.Err()) + t.Logf("Successfully processed %d results from spilled group by operation", count) +} + +func TestGroupSpillWithStrings(t *testing.T) { + cluster, err := embed.NewCluster( + 
embed.WithCNCount(1), + embed.WithTesting(), + embed.WithPreStart(func(service embed.ServiceOperator) { + service.Adjust(func(config *embed.ServiceConfig) { + config.Log.Level = "debug" + }) + }), + ) + require.NoError(t, err) + err = cluster.Start() + require.NoError(t, err) + defer cluster.Close() + + cn0, err := cluster.GetCNService(0) + require.NoError(t, err) + dsn := fmt.Sprintf("dump:111@tcp(127.0.0.1:%d)/", + cn0.GetServiceConfig().CN.Frontend.Port, + ) + + db, err := sql.Open("mysql", dsn) + require.NoError(t, err) + defer db.Close() + + _, err = db.Exec(`CREATE DATABASE IF NOT EXISTS test_spill`) + require.NoError(t, err) + + _, err = db.Exec(`USE test_spill`) + require.NoError(t, err) + + _, err = db.Exec(` + CREATE TABLE IF NOT EXISTS string_group_test ( + id BIGINT PRIMARY KEY, + category VARCHAR(50), + description TEXT + ) + `) + require.NoError(t, err) + + _, err = db.Exec(` + INSERT INTO string_group_test (id, category, description) + SELECT + g.result as id, + CONCAT('category_', FLOOR(RAND() * 10000)) as category, -- 10k categories + CONCAT('description_', g.result, '_long_text_data_for_spill_testing') as description + FROM generate_series(500000) g -- 500k rows + `) + require.NoError(t, err) + + rows, err := db.Query(` + SELECT + category, + GROUP_CONCAT(description ORDER BY id SEPARATOR ', ') as concatenated_desc, + COUNT(*) as cnt + FROM string_group_test + GROUP BY category + HAVING cnt > 1 + ORDER BY cnt DESC + LIMIT 5 + `) + require.NoError(t, err) + defer rows.Close() + + count := 0 + for rows.Next() { + count++ + } + require.NoError(t, rows.Err()) + t.Logf("Successfully processed %d results from string group concat spill operation", count) +} + +func TestGroupSpillApproxCountDistinct(t *testing.T) { + cluster, err := embed.NewCluster( + embed.WithCNCount(1), + embed.WithTesting(), + embed.WithPreStart(func(service embed.ServiceOperator) { + service.Adjust(func(config *embed.ServiceConfig) { + config.Log.Level = "debug" + }) + }), + ) + 
require.NoError(t, err) + err = cluster.Start() + require.NoError(t, err) + defer cluster.Close() + + cn0, err := cluster.GetCNService(0) + require.NoError(t, err) + dsn := fmt.Sprintf("dump:111@tcp(127.0.0.1:%d)/", + cn0.GetServiceConfig().CN.Frontend.Port, + ) + + db, err := sql.Open("mysql", dsn) + require.NoError(t, err) + defer db.Close() + + _, err = db.Exec(`CREATE DATABASE IF NOT EXISTS test_spill`) + require.NoError(t, err) + + _, err = db.Exec(`USE test_spill`) + require.NoError(t, err) + + _, err = db.Exec(` + CREATE TABLE IF NOT EXISTS approx_count_test ( + id BIGINT PRIMARY KEY, + group_col VARCHAR(100), + value_col BIGINT + ) + `) + require.NoError(t, err) + + _, err = db.Exec(` + INSERT INTO approx_count_test (id, group_col, value_col) + SELECT + g.result as id, + CONCAT('group_', FLOOR(RAND() * 5000)) as group_col, -- 5k groups + FLOOR(RAND() * 1000000) as value_col -- many distinct values per group + FROM generate_series(800000) g -- 800k rows + `) + require.NoError(t, err) + + rows, err := db.Query(` + SELECT + group_col, + APPROX_COUNT_DISTINCT(value_col) as approx_distinct_count, + COUNT(*) as total_count + FROM approx_count_test + GROUP BY group_col + ORDER BY approx_distinct_count DESC + LIMIT 10 + `) + require.NoError(t, err) + defer rows.Close() + + count := 0 + for rows.Next() { + count++ + } + require.NoError(t, rows.Err()) + t.Logf("Successfully processed %d results from approx count distinct spill operation", count) +} + +func TestGroupSpillMixedAggregations(t *testing.T) { + cluster, err := embed.NewCluster( + embed.WithCNCount(1), + embed.WithTesting(), + embed.WithPreStart(func(service embed.ServiceOperator) { + service.Adjust(func(config *embed.ServiceConfig) { + config.Log.Level = "debug" + }) + }), + ) + require.NoError(t, err) + err = cluster.Start() + require.NoError(t, err) + defer cluster.Close() + + cn0, err := cluster.GetCNService(0) + require.NoError(t, err) + dsn := fmt.Sprintf("dump:111@tcp(127.0.0.1:%d)/", + 
cn0.GetServiceConfig().CN.Frontend.Port, + ) + + db, err := sql.Open("mysql", dsn) + require.NoError(t, err) + defer db.Close() + + _, err = db.Exec(`CREATE DATABASE IF NOT EXISTS test_spill`) + require.NoError(t, err) + + _, err = db.Exec(`USE test_spill`) + require.NoError(t, err) + + _, err = db.Exec(` + CREATE TABLE IF NOT EXISTS mixed_agg_test ( + id BIGINT PRIMARY KEY, + dept VARCHAR(50), + product VARCHAR(100), + sales_amount DECIMAL(15,2), + quantity INT, + sale_date DATE + ) + `) + require.NoError(t, err) + + _, err = db.Exec(` + INSERT INTO mixed_agg_test (id, dept, product, sales_amount, quantity, sale_date) + SELECT + g.result as id, + CONCAT('dept_', FLOOR(RAND() * 2000)) as dept, -- 2k departments + CONCAT('product_', FLOOR(RAND() * 50000)) as product, -- 50k products + RAND() * 1000 as sales_amount, + FLOOR(RAND() * 100) as quantity, + DATE_ADD('2023-01-01', INTERVAL FLOOR(RAND() * 365) DAY) as sale_date + FROM generate_series(600000) g -- 600k rows + `) + require.NoError(t, err) + + rows, err := db.Query(` + SELECT + dept, + product, + COUNT(*) as transaction_count, + SUM(sales_amount) as total_sales, + AVG(sales_amount) as avg_sales, + MIN(sales_amount) as min_sales, + MAX(sales_amount) as max_sales, + SUM(quantity) as total_quantity + FROM mixed_agg_test + GROUP BY dept, product + ORDER BY total_sales DESC + LIMIT 10 + `) + require.NoError(t, err) + defer rows.Close() + + count := 0 + for rows.Next() { + count++ + } + require.NoError(t, rows.Err()) + t.Logf("Successfully processed %d results from mixed aggregations spill operation", count) +} diff --git a/pkg/sql/colexec/group/types.go b/pkg/sql/colexec/group/types.go index 0243e4cda69d2..ae3e713963a05 100644 --- a/pkg/sql/colexec/group/types.go +++ b/pkg/sql/colexec/group/types.go @@ -93,6 +93,10 @@ type Group struct { GroupingFlag []bool // agg info and agg column. 
Aggs []aggexec.AggFuncExecExpression + + // spill configuration + SpillManager SpillManager + SpillThreshold int64 } func (group *Group) evaluateGroupByAndAgg(proc *process.Process, bat *batch.Batch) (err error) { @@ -161,6 +165,11 @@ type container struct { result1 GroupResultBuffer // result if NeedEval is false. result2 GroupResultNoneBlock + + // spill state + currentMemUsage int64 + spilledStates []SpillID + spillPending bool } func (ctr *container) isDataSourceEmpty() bool { @@ -173,6 +182,11 @@ func (group *Group) Free(proc *process.Process, _ bool, _ error) { group.ctr.freeGroupEvaluate() group.ctr.freeAggEvaluate() group.FreeProjection(proc) + + if group.SpillManager != nil { + group.SpillManager.Free() + group.SpillManager = nil + } } func (group *Group) Reset(proc *process.Process, pipelineFailed bool, err error) { @@ -189,6 +203,13 @@ func (group *Group) freeCannotReuse(mp *mpool.MPool) { group.ctr.hr.Free0() group.ctr.result1.Free0(mp) group.ctr.result2.Free0(mp) + + for _, id := range group.ctr.spilledStates { + if group.SpillManager != nil { + group.SpillManager.Delete(id) + } + } + group.ctr.spilledStates = nil } func (ctr *container) freeAggEvaluate() {