Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions pkg/scheduler/plugins/topology/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright 2025 NVIDIA CORPORATION
// SPDX-License-Identifier: Apache-2.0

package topology

import (
"crypto/sha256"
"encoding/hex"
"slices"
"strings"

"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/node_info"
"github.com/samber/lo"
)

// nodeSetID is the map key that identifies a set of nodes. It is an alias of
// string; getNodeSetID produces it as a hex-encoded SHA-256 digest.
type nodeSetID = string

// getNodeSetID derives an identifier for a node set (list of nodes) that can be
// used as a key in a map. The node names are sorted before they are joined and
// hashed, so the result does not depend on the order of the nodes in nodeSet.
// The identifier is the hex-encoded SHA-256 digest of the comma-joined names.
func getNodeSetID(nodeSet node_info.NodeSet) nodeSetID {
	names := lo.Map(nodeSet, func(n *node_info.NodeInfo, _ int) string {
		return n.Name
	})
	// Sort in place so that two sets holding the same nodes yield the same key.
	slices.Sort(names)

	digest := sha256.Sum256([]byte(strings.Join(names, ",")))
	return nodeSetID(hex.EncodeToString(digest[:]))
}
9 changes: 7 additions & 2 deletions pkg/scheduler/plugins/topology/job_filtering.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,21 @@ func (t *topologyPlugin) subSetNodesFn(
return []node_info.NodeSet{}, nil
}

domain, ok := t.nodeSetToDomain[topologyTree.Name][getNodeSetID(nodeSet)]
if !ok {
return nil, fmt.Errorf("domain not found for node set in topology %s", topologyTree.Name)
}

// Sorting the tree for both packing and closest preferred level domain scoring
preferredLevel := DomainLevel(subGroup.GetTopologyConstraint().PreferredLevel)
requiredLevel := DomainLevel(subGroup.GetTopologyConstraint().RequiredLevel)
maxDepthLevel := preferredLevel
if maxDepthLevel == "" {
maxDepthLevel = requiredLevel
}
sortTree(topologyTree.DomainsByLevel[rootLevel][rootDomainId], maxDepthLevel)
sortTree(domain, maxDepthLevel)
if preferredLevel != "" {
t.subGroupNodeScores[subGroup.GetName()] = calculateNodeScores(topologyTree.DomainsByLevel[rootLevel][rootDomainId], preferredLevel)
t.subGroupNodeScores[subGroup.GetName()] = calculateNodeScores(domain, preferredLevel)
}

jobAllocatableDomains, err := t.getJobAllocatableDomains(job, subGroup, podSets, len(tasks), topologyTree)
Expand Down
15 changes: 15 additions & 0 deletions pkg/scheduler/plugins/topology/job_filtering_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"sort"
"testing"

"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/maps"
v1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -394,12 +395,26 @@ func TestTopologyPlugin_subsetNodesFn(t *testing.T) {
}
}

// Setup nodeSetToDomain mapping
nodeSetToDomain := map[topologyName]map[nodeSetID]*DomainInfo{}
nodeSetToDomain[topologyTree.Name] = map[nodeSetID]*DomainInfo{}
domains := []*DomainInfo{}
for _, levelDomains := range topologyTree.DomainsByLevel {
for _, domain := range levelDomains {
domains = append(domains, domain)
}
}
for _, domain := range domains {
nodeSetToDomain[topologyTree.Name][getNodeSetID(lo.Values(domain.Nodes))] = domain
}

// Setup plugin
plugin := &topologyPlugin{
TopologyTrees: map[string]*Info{
"test-topology": topologyTree,
},
subGroupNodeScores: map[subgroupName]map[string]float64{},
nodeSetToDomain: nodeSetToDomain,
}

// Call the function under test
Expand Down
22 changes: 21 additions & 1 deletion pkg/scheduler/plugins/topology/topology_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/node_info"
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/podgroup_info"
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/framework"
"github.com/samber/lo"
)

const (
Expand All @@ -25,13 +26,15 @@ type topologyPlugin struct {

// Defines order among nodes in a sub-group based on the sub-group's preferred level topology constraint.
subGroupNodeScores map[subgroupName]map[string]float64
nodeSetToDomain map[topologyName]map[nodeSetID]*DomainInfo
session *framework.Session
}

// New constructs the topology plugin with empty topology trees, sub-group node
// scores and node-set-to-domain mappings. The plugin arguments are ignored, and
// the session field is left at its zero value (nil) by this constructor.
func New(_ map[string]string) framework.Plugin {
	plugin := &topologyPlugin{
		TopologyTrees:      make(map[topologyName]*Info),
		subGroupNodeScores: make(map[subgroupName]map[string]float64),
		nodeSetToDomain:    make(map[topologyName]map[nodeSetID]*DomainInfo),
	}
	return plugin
}
Expand All @@ -41,8 +44,8 @@ func (t *topologyPlugin) Name() string {
}

func (t *topologyPlugin) OnSessionOpen(ssn *framework.Session) {
t.initializeTopologyTree(ssn.Topologies, ssn.Nodes)
t.session = ssn
t.initializeTopologyTree(ssn.Topologies, ssn.Nodes)

ssn.AddSubsetNodesFn(t.subSetNodesFn)
ssn.AddNodeOrderFn(t.nodeOrderFn)
Expand Down Expand Up @@ -71,9 +74,26 @@ func (t *topologyPlugin) initializeTopologyTree(topologies []*kueuev1alpha1.Topo
}

t.TopologyTrees[topology.Name] = topologyTree

t.buildNodeSetToDomainMapping(topology.Name, topologyTree)
}
}

// buildNodeSetToDomainMapping replaces the node-set lookup table of the given
// topology with a fresh one. Every domain found in topologyTree.DomainsByLevel
// is registered under the ID of the nodes it holds, and the ID of all session
// nodes is then mapped to the root domain.
//
// It reads t.session.Nodes, so t.session must be non-nil when this is called.
//
// NOTE(review): domains that hold exactly the same nodes produce the same key,
// so the later write wins and map iteration order decides which domain that is
// — confirm this is acceptable for callers.
func (t *topologyPlugin) buildNodeSetToDomainMapping(topologyName topologyName, topologyTree *Info) {
	byNodeSet := map[nodeSetID]*DomainInfo{}
	t.nodeSetToDomain[topologyName] = byNodeSet

	for _, levelDomains := range topologyTree.DomainsByLevel {
		for _, domain := range levelDomains {
			byNodeSet[getNodeSetID(lo.Values(domain.Nodes))] = domain
		}
	}

	// Registered last so the full session node set always resolves to the root
	// domain, overriding any domain entry that happens to share the same key.
	root := topologyTree.DomainsByLevel[rootLevel][rootDomainId]
	byNodeSet[getNodeSetID(lo.Values(t.session.Nodes))] = root
}

func (*topologyPlugin) addNodeDataToTopology(topologyTree *Info, topology *kueuev1alpha1.Topology, nodeInfo *node_info.NodeInfo) {
// Validate that the node is part of the topology
if !isNodePartOfTopology(nodeInfo, topology) {
Expand Down
Loading