diff --git a/pkg/scheduler/plugins/topology/common.go b/pkg/scheduler/plugins/topology/common.go new file mode 100644 index 000000000..8b31f15e3 --- /dev/null +++ b/pkg/scheduler/plugins/topology/common.go @@ -0,0 +1,29 @@ +// Copyright 2025 NVIDIA CORPORATION +// SPDX-License-Identifier: Apache-2.0 + +package topology + +import ( + "crypto/sha256" + "encoding/hex" + "slices" + "strings" + + "github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/node_info" + "github.com/samber/lo" +) + +type nodeSetID = string + +// Function that accepts a node set (list of nodes) and return an identifier for the node set that can +// be used as a key in a map +func getNodeSetID(nodeSet node_info.NodeSet) nodeSetID { + nodeNames := lo.Map(nodeSet, func(node *node_info.NodeInfo, _ int) string { + return node.Name + }) + slices.Sort(nodeNames) + concatenated := strings.Join(nodeNames, ",") + + hash := sha256.Sum256([]byte(concatenated)) + return nodeSetID(hex.EncodeToString(hash[:])) +} diff --git a/pkg/scheduler/plugins/topology/job_filtering.go b/pkg/scheduler/plugins/topology/job_filtering.go index be09896bc..d320f03d0 100644 --- a/pkg/scheduler/plugins/topology/job_filtering.go +++ b/pkg/scheduler/plugins/topology/job_filtering.go @@ -53,6 +53,11 @@ func (t *topologyPlugin) subSetNodesFn( return []node_info.NodeSet{}, nil } + domain, ok := t.nodeSetToDomain[topologyTree.Name][getNodeSetID(nodeSet)] + if !ok { + return nil, fmt.Errorf("domain not found for node set in topology %s", topologyTree.Name) + } + // Sorting the tree for both packing and closest preferred level domain scoring preferredLevel := DomainLevel(subGroup.GetTopologyConstraint().PreferredLevel) requiredLevel := DomainLevel(subGroup.GetTopologyConstraint().RequiredLevel) @@ -60,9 +65,9 @@ func (t *topologyPlugin) subSetNodesFn( if maxDepthLevel == "" { maxDepthLevel = requiredLevel } - sortTree(topologyTree.DomainsByLevel[rootLevel][rootDomainId], maxDepthLevel) + sortTree(domain, maxDepthLevel) if preferredLevel != "" { - t.subGroupNodeScores[subGroup.GetName()] = calculateNodeScores(topologyTree.DomainsByLevel[rootLevel][rootDomainId], preferredLevel) + t.subGroupNodeScores[subGroup.GetName()] = calculateNodeScores(domain, preferredLevel) } jobAllocatableDomains, err := t.getJobAllocatableDomains(job, subGroup, podSets, len(tasks), topologyTree) diff --git a/pkg/scheduler/plugins/topology/job_filtering_test.go b/pkg/scheduler/plugins/topology/job_filtering_test.go index ade5274b7..1e30d0f93 100644 --- a/pkg/scheduler/plugins/topology/job_filtering_test.go +++ b/pkg/scheduler/plugins/topology/job_filtering_test.go @@ -8,6 +8,7 @@ import ( "sort" "testing" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "golang.org/x/exp/maps" v1 "k8s.io/api/core/v1" @@ -394,12 +395,26 @@ func TestTopologyPlugin_subsetNodesFn(t *testing.T) { } } + // Setup nodeSetToDomain mapping + nodeSetToDomain := map[topologyName]map[nodeSetID]*DomainInfo{} + nodeSetToDomain[topologyTree.Name] = map[nodeSetID]*DomainInfo{} + domains := []*DomainInfo{} + for _, levelDomains := range topologyTree.DomainsByLevel { + for _, domain := range levelDomains { + domains = append(domains, domain) + } + } + for _, domain := range domains { + nodeSetToDomain[topologyTree.Name][getNodeSetID(lo.Values(domain.Nodes))] = domain + } + // Setup plugin plugin := &topologyPlugin{ TopologyTrees: map[string]*Info{ "test-topology": topologyTree, }, subGroupNodeScores: map[subgroupName]map[string]float64{}, + nodeSetToDomain: nodeSetToDomain, } // Call the function under test diff --git a/pkg/scheduler/plugins/topology/topology_plugin.go b/pkg/scheduler/plugins/topology/topology_plugin.go index 006657d7b..336c264dc 100644 --- a/pkg/scheduler/plugins/topology/topology_plugin.go +++ b/pkg/scheduler/plugins/topology/topology_plugin.go @@ -9,6 +9,7 @@ import ( "github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/node_info" "github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/podgroup_info" "github.com/NVIDIA/KAI-scheduler/pkg/scheduler/framework" + "github.com/samber/lo" ) const ( @@ -25,6 +26,7 @@ type topologyPlugin struct { // Defines order among nodes in a sub-group based on the sub-group's preferred level topology constraint. subGroupNodeScores map[subgroupName]map[string]float64 + nodeSetToDomain map[topologyName]map[nodeSetID]*DomainInfo session *framework.Session } @@ -32,6 +34,7 @@ func New(_ framework.PluginArguments) framework.Plugin { return &topologyPlugin{ TopologyTrees: map[topologyName]*Info{}, subGroupNodeScores: map[subgroupName]map[string]float64{}, + nodeSetToDomain: map[topologyName]map[nodeSetID]*DomainInfo{}, session: nil, } } @@ -41,8 +44,8 @@ func (t *topologyPlugin) Name() string { } func (t *topologyPlugin) OnSessionOpen(ssn *framework.Session) { - t.initializeTopologyTree(ssn.Topologies, ssn.Nodes) t.session = ssn + t.initializeTopologyTree(ssn.Topologies, ssn.Nodes) ssn.AddSubsetNodesFn(t.subSetNodesFn) ssn.AddNodeOrderFn(t.nodeOrderFn) @@ -71,9 +74,26 @@ func (t *topologyPlugin) initializeTopologyTree(topologies []*kueuev1alpha1.Topo } t.TopologyTrees[topology.Name] = topologyTree + + t.buildNodeSetToDomainMapping(topology.Name, topologyTree) } } +func (t *topologyPlugin) buildNodeSetToDomainMapping(topologyName topologyName, topologyTree *Info) { + t.nodeSetToDomain[topologyName] = map[nodeSetID]*DomainInfo{} + domains := []*DomainInfo{} + for _, levelDomains := range topologyTree.DomainsByLevel { + for _, domain := range levelDomains { + domains = append(domains, domain) + } + } + for _, domain := range domains { + t.nodeSetToDomain[topologyName][getNodeSetID(lo.Values(domain.Nodes))] = domain + } + + t.nodeSetToDomain[topologyName][getNodeSetID(lo.Values(t.session.Nodes))] = topologyTree.DomainsByLevel[rootLevel][rootDomainId] +} + func (*topologyPlugin) addNodeDataToTopology(topologyTree *Info, topology *kueuev1alpha1.Topology, nodeInfo *node_info.NodeInfo) { // Validate that the node is part of the topology if !isNodePartOfTopology(nodeInfo, topology) {