17 changes: 17 additions & 0 deletions deployments/kai-scheduler/templates/rbac/podgrouper.yaml
@@ -158,6 +158,22 @@ rules:
- create
- patch
- update
- apiGroups:
- leaderworkerset.x-k8s.io
resources:
- leaderworkersets
verbs:
- get
- list
- watch
- apiGroups:
- leaderworkerset.x-k8s.io
resources:
- leaderworkersets/finalizers
verbs:
- create
- patch
- update
- apiGroups:
- machinelearning.seldon.io
resources:
@@ -202,6 +218,7 @@ rules:
- interactiveworkloads
- runaijobs
- trainingworkloads
- distributedinferenceworkloads
verbs:
- get
- list
1 change: 1 addition & 0 deletions go.mod
@@ -62,6 +62,7 @@ require (
knative.dev/serving v0.44.0
sigs.k8s.io/controller-runtime v0.20.0
sigs.k8s.io/karpenter v1.2.0
sigs.k8s.io/lws v0.5.1
)

require (
2 changes: 2 additions & 0 deletions go.sum
@@ -600,6 +600,8 @@ sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8 h1:gBQPwqORJ8d8/YNZWEjoZs7np
sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg=
sigs.k8s.io/karpenter v1.2.0 h1:y1zyFGzDLiT5OHpG8Jvj4JKKe/iXmJDYeejs8k8OznM=
sigs.k8s.io/karpenter v1.2.0/go.mod h1:646txj32arNTy+K4gySCqWSljYrEdemAdYoBMQmkS7o=
sigs.k8s.io/lws v0.5.1 h1:eaeMNkP0manRluQZLN32atoULaGrzP611gSLdFaHZs4=
sigs.k8s.io/lws v0.5.1/go.mod h1:qprXSTTFnfmPZY3V3sUfk6ZPmAodsdoKS8XVElJ9kN0=
sigs.k8s.io/structured-merge-diff/v4 v4.5.0 h1:nbCitCK2hfnhyiKo6uf2HxUPTCodY6Qaf85SbDIaMBk=
sigs.k8s.io/structured-merge-diff/v4 v4.5.0/go.mod h1:N8f93tFZh9U6vpxwRArLiikrE5/2tiu1w1AGfACIGE4=
sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E=
1 change: 1 addition & 0 deletions hack/run-e2e-kind.sh
@@ -26,6 +26,7 @@ if [ "$TEST_THIRD_PARTY_INTEGRATIONS" = "true" ]; then
${REPO_ROOT}/hack/third_party_integrations/deploy_ray.sh
${REPO_ROOT}/hack/third_party_integrations/deploy_kubeflow.sh
${REPO_ROOT}/hack/third_party_integrations/deploy_knative.sh
${REPO_ROOT}/hack/third_party_integrations/deploy_lws.sh
fi

PACKAGE_VERSION=0.0.0-$(git rev-parse --short origin/main)
7 changes: 7 additions & 0 deletions hack/third_party_integrations/deploy_lws.sh
@@ -0,0 +1,7 @@
#!/bin/bash
# Copyright 2025 NVIDIA CORPORATION
# SPDX-License-Identifier: Apache-2.0
set -e

CHART_VERSION=0.6.1
helm install lws oci://registry.k8s.io/lws/charts/lws --version=$CHART_VERSION --namespace lws-system --create-namespace --wait --timeout 300s
25 changes: 18 additions & 7 deletions pkg/podgrouper/podgrouper/hub/hub.go
@@ -21,6 +21,7 @@ import (
pytorchplugin "github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/kubeflow/pytorch"
tensorflowlugin "github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/kubeflow/tensorflow"
xgboostplugin "github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/kubeflow/xgboost"
"github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/lws"
"github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/podjob"
"github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/ray"
"github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/runaijob"
@@ -30,12 +31,13 @@ import (
)

const (
apiGroupArgo = "argoproj.io"
apiGroupRunai = "run.ai"
kindTrainingWorkload = "TrainingWorkload"
kindInteractiveWorkload = "InteractiveWorkload"
kindDistributedWorkload = "DistributedWorkload"
kindInferenceWorkload = "InferenceWorkload"
apiGroupArgo = "argoproj.io"
apiGroupRunai = "run.ai"
kindTrainingWorkload = "TrainingWorkload"
kindInteractiveWorkload = "InteractiveWorkload"
kindDistributedWorkload = "DistributedWorkload"
kindInferenceWorkload = "InferenceWorkload"
kindDistributedInferenceWorkload = "DistributedInferenceWorkload"
)

// +kubebuilder:rbac:groups=apps,resources=replicasets;statefulsets,verbs=get;list;watch
@@ -51,6 +53,8 @@ const (
// +kubebuilder:rbac:groups=tekton.dev,resources=pipelineruns;taskruns,verbs=get;list;watch
// +kubebuilder:rbac:groups=tekton.dev,resources=pipelineruns/finalizers;taskruns/finalizers,verbs=patch;update;create
// +kubebuilder:rbac:groups=run.ai,resources=trainingworkloads;interactiveworkloads;distributedworkloads;inferenceworkloads,verbs=get;list;watch
// +kubebuilder:rbac:groups=leaderworkerset.x-k8s.io,resources=leaderworkersets,verbs=get;list;watch
// +kubebuilder:rbac:groups=leaderworkerset.x-k8s.io,resources=leaderworkersets/finalizers,verbs=patch;update;create

type PluginsHub struct {
defaultPlugin *defaultgrouper.DefaultGrouper
@@ -76,6 +80,7 @@ func NewPluginsHub(kubeClient client.Client, searchForLegacyPodGroups,

kubeFlowDistributedGrouper := kubeflow.NewKubeflowDistributedGrouper(defaultGrouper)
mpiGrouper := mpi.NewMpiGrouper(kubeClient, kubeFlowDistributedGrouper)
lwsGrouper := lws.NewLwsGrouper(defaultGrouper)

rayGrouper := ray.NewRayGrouper(kubeClient, defaultGrouper)
rayClusterGrouper := ray.NewRayClusterGrouper(rayGrouper)
@@ -236,6 +241,11 @@ func NewPluginsHub(kubeClient client.Client, searchForLegacyPodGroups,
Version: "v1",
Kind: "SPOTRequest",
}: spotrequest.NewSpotRequestGrouper(defaultGrouper),
{
Group: "leaderworkerset.x-k8s.io",
Version: "v1",
Kind: "LeaderWorkerSet",
}: lwsGrouper,
}

skipTopOwnerGrouper := skiptopowner.NewSkipTopOwnerGrouper(kubeClient, defaultGrouper, table)
@@ -245,7 +255,8 @@ func NewPluginsHub(kubeClient client.Client, searchForLegacyPodGroups,
Kind: "Workflow",
}] = skipTopOwnerGrouper

for _, kind := range []string{kindInferenceWorkload, kindTrainingWorkload, kindDistributedWorkload, kindInteractiveWorkload} {
for _, kind := range []string{
kindInferenceWorkload, kindTrainingWorkload, kindDistributedWorkload, kindInteractiveWorkload, kindDistributedInferenceWorkload} {
table[metav1.GroupVersionKind{
Group: apiGroupRunai,
Version: "*",
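As a hedged, self-contained illustration (not part of this PR): a pod created by a LeaderWorkerSet carries an owner reference whose apiVersion and kind parse into the same GroupVersionKind that is registered above for lwsGrouper. The LeaderWorkerSet name below is hypothetical.

// Sketch only; shows how an LWS owner reference maps to the table key used by the hub.
package main

import (
	"fmt"

	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime/schema"
)

func main() {
	// Owner reference as a pod created by a hypothetical LeaderWorkerSet "vllm" would carry it.
	owner := metav1.OwnerReference{
		APIVersion: "leaderworkerset.x-k8s.io/v1",
		Kind:       "LeaderWorkerSet",
		Name:       "vllm",
	}

	gv, err := schema.ParseGroupVersion(owner.APIVersion)
	if err != nil {
		panic(err)
	}

	// Same group/version/kind as the entry registered for the LWS grouper in NewPluginsHub.
	key := metav1.GroupVersionKind{Group: gv.Group, Version: gv.Version, Kind: owner.Kind}
	fmt.Printf("grouper table key: %s/%s, Kind=%s\n", key.Group, key.Version, key.Kind)
}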
157 changes: 157 additions & 0 deletions pkg/podgrouper/podgrouper/plugins/lws/leader_worker_set.go
@@ -0,0 +1,157 @@
// Copyright 2025 NVIDIA CORPORATION
// SPDX-License-Identifier: Apache-2.0

package lws

import (
"context"
"fmt"
"strconv"

v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"sigs.k8s.io/controller-runtime/pkg/log"

"github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgroup"
"github.com/NVIDIA/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/defaultgrouper"
)

const (
// LWS labels
lwsNameLabel = "leaderworkerset.sigs.k8s.io/name"
lwsGroupIndexLabel = "leaderworkerset.sigs.k8s.io/group-index"
lwsWorkerIndexLabel = "leaderworkerset.sigs.k8s.io/worker-index"

// LWS startup policies
leaderReadyStartupPolicy = "LeaderReady"
)

var (
logger = log.FromContext(context.Background())
)

type LwsGrouper struct {
*defaultgrouper.DefaultGrouper
}

func NewLwsGrouper(defaultGrouper *defaultgrouper.DefaultGrouper) *LwsGrouper {
return &LwsGrouper{
DefaultGrouper: defaultGrouper,
}
}

func (lg *LwsGrouper) Name() string {
return "LwsGrouper"
}

func (lg *LwsGrouper) GetPodGroupMetadata(
lwsJob *unstructured.Unstructured, pod *v1.Pod, _ ...*metav1.PartialObjectMetadata,
) (*podgroup.Metadata, error) {
podGroupMetadata, err := lg.DefaultGrouper.GetPodGroupMetadata(lwsJob, pod)
if err != nil {
return nil, err
}

// Modify the podGroup name to include the group index so each LWS replica gets its own podGroup
groupIndex, err := lg.getGroupIndexFromPod(pod)
if err != nil {
return nil, err
}
podGroupMetadata.Name = fmt.Sprintf("%s-group-%d", podGroupMetadata.Name, groupIndex)

// Get the size of one LeaderWorkerSet group (leader + workers)
groupSize, err := lg.getLwsGroupSize(lwsJob)
if err != nil {
return nil, err
}
podGroupMetadata.MinAvailable = groupSize

startupPolicy, found, err := unstructured.NestedString(lwsJob.Object, "spec", "startupPolicy")
if err != nil {
logger.V(1).Info("failed to extract lws startup policy", "namespace", lwsJob.GetNamespace(), "name", lwsJob.GetName(), "error", err)
}
if found && startupPolicy == leaderReadyStartupPolicy {
if err := lg.handleLeaderReadyStartupPolicy(pod, podGroupMetadata); err != nil {
return podGroupMetadata, err
}
}

return podGroupMetadata, nil
}

func (lg *LwsGrouper) getGroupIndexFromPod(pod *v1.Pod) (int, error) {
groupIndexStr, found := pod.Labels[lwsGroupIndexLabel]
if !found {
return 0, fmt.Errorf("pod %s/%s does not have group index label %s", pod.Namespace, pod.Name, lwsGroupIndexLabel)
}

groupIndex, err := strconv.Atoi(groupIndexStr)
if err != nil {
return 0, fmt.Errorf("invalid group index %s for pod %s/%s: %w", groupIndexStr, pod.Namespace, pod.Name, err)
}

return groupIndex, nil
}

func (lg *LwsGrouper) getLwsGroupSize(lwsJob *unstructured.Unstructured) (int32, error) {
size, found, err := unstructured.NestedInt64(lwsJob.Object, "spec", "leaderWorkerTemplate", "size")
if err != nil {
return 0, fmt.Errorf("failed to get leaderWorkerTemplate.size from LWS %s/%s with error: %w",
lwsJob.GetNamespace(), lwsJob.GetName(), err)
}
if !found {
return 0, fmt.Errorf("leaderWorkerTemplate.size not found in LWS %s/%s", lwsJob.GetNamespace(), lwsJob.GetName())
}
if size <= 0 {
return 0, fmt.Errorf("invalid leaderWorkerTemplate.size %d in LWS %s/%s", size, lwsJob.GetNamespace(), lwsJob.GetName())
}

return int32(size), nil
}

func (lg *LwsGrouper) handleLeaderReadyStartupPolicy(
pod *v1.Pod,
podGroupMetadata *podgroup.Metadata,
) error {
// If the startup policy is LeaderReady and this pod isn't the leader, the leader must already be
// ready and the workers are being created, so the whole group should be scheduled.
if !lg.isPodLeader(pod) {
return nil
}

// If this pod is the leader and it is already ready (and not being deleted), the workers are being created and the whole group should be scheduled.
if lg.isPodReady(pod) && pod.GetDeletionTimestamp() == nil {
return nil
}

// Otherwise the leader is not ready yet; under the LeaderReady policy its workers do not exist yet, so the leader must be scheduled on its own.
podGroupMetadata.MinAvailable = 1

return nil
}

func (lg *LwsGrouper) isPodLeader(pod *v1.Pod) bool {
// In LWS, leader pods typically don't have the worker-index label or have worker-index=0
workerIndexStr, hasWorkerIndex := pod.Labels[lwsWorkerIndexLabel]
if !hasWorkerIndex {
return true // No worker index means it's likely a leader
}

// If worker index is 0, it might be a leader (depending on LWS implementation)
workerIndex, err := strconv.Atoi(workerIndexStr)
if err != nil {
return false
}

return workerIndex == 0
}

func (lg *LwsGrouper) isPodReady(pod *v1.Pod) bool {
for _, condition := range pod.Status.Conditions {
if condition.Type == v1.PodReady && condition.Status == v1.ConditionTrue {
return true
}
}
return false
}
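To make the grouping behaviour concrete, here is a hedged sketch (not from the PR) of the extraction the grouper above performs: spec.leaderWorkerTemplate.size on the LeaderWorkerSet and the leaderworkerset.sigs.k8s.io/group-index label on the pod together determine the per-replica pod group name suffix and minAvailable. The object names and sizes are hypothetical test fixtures.

// Minimal sketch, assuming an LWS object and a worker pod shaped like the fixtures below.
package main

import (
	"fmt"
	"strconv"

	v1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
)

func main() {
	// Hypothetical LeaderWorkerSet reduced to the fields the grouper reads.
	lws := &unstructured.Unstructured{Object: map[string]interface{}{
		"apiVersion": "leaderworkerset.x-k8s.io/v1",
		"kind":       "LeaderWorkerSet",
		"metadata":   map[string]interface{}{"name": "vllm", "namespace": "default"},
		"spec": map[string]interface{}{
			"startupPolicy": "LeaderReady",
			"leaderWorkerTemplate": map[string]interface{}{
				"size": int64(4), // one leader plus three workers per replica group
			},
		},
	}}

	// Hypothetical worker pod belonging to replica group 2.
	pod := &v1.Pod{ObjectMeta: metav1.ObjectMeta{
		Name:      "vllm-2-1",
		Namespace: "default",
		Labels: map[string]string{
			"leaderworkerset.sigs.k8s.io/name":         "vllm",
			"leaderworkerset.sigs.k8s.io/group-index":  "2",
			"leaderworkerset.sigs.k8s.io/worker-index": "1",
		},
	}}

	size, found, err := unstructured.NestedInt64(lws.Object, "spec", "leaderWorkerTemplate", "size")
	if err != nil || !found {
		panic("leaderWorkerTemplate.size missing")
	}
	groupIndex, err := strconv.Atoi(pod.Labels["leaderworkerset.sigs.k8s.io/group-index"])
	if err != nil {
		panic(err)
	}

	// Mirrors the grouper's result: one pod group per replica group,
	// sized to the whole group (leader + workers).
	fmt.Printf("pod group name suffix: group-%d, minAvailable: %d\n", groupIndex, size)
}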