Skip to content

Commit

Permalink
Compatible with training-operator CRD. (#1024)
Browse files Browse the repository at this point in the history
Signed-off-by: Syulin7 <[email protected]>
  • Loading branch information
Syulin7 committed Jan 16, 2024
1 parent 67a9150 commit cdf1bb3
Show file tree
Hide file tree
Showing 12 changed files with 126 additions and 36 deletions.
23 changes: 18 additions & 5 deletions charts/pytorchjob/templates/pytorchjob.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,27 @@ metadata:
{{ $key }}: {{ $value | quote }}
{{- end }}
spec:
{{- if .Values.cleanPodPolicy }}
{{- if .Values.trainingOperatorCRD }}
runPolicy:
{{- if .Values.cleanPodPolicy }}
cleanPodPolicy: {{ .Values.cleanPodPolicy }}
{{- end }}
{{- if .Values.activeDeadlineSeconds }}
activeDeadlineSeconds: {{ .Values.activeDeadlineSeconds }}
{{- end }}
{{- if .Values.ttlSecondsAfterFinished }}
ttlSecondsAfterFinished: {{ .Values.ttlSecondsAfterFinished }}
{{- end }}
{{- else }}
{{- if .Values.cleanPodPolicy }}
cleanPodPolicy: {{ .Values.cleanPodPolicy }}
{{- end }}
{{- if .Values.activeDeadlineSeconds }}
{{- end }}
{{- if .Values.activeDeadlineSeconds }}
activeDeadlineSeconds: {{ .Values.activeDeadlineSeconds }}
{{- end }}
{{- if .Values.ttlSecondsAfterFinished }}
{{- end }}
{{- if .Values.ttlSecondsAfterFinished }}
ttlSecondsAfterFinished: {{ .Values.ttlSecondsAfterFinished }}
{{- end }}
{{- end }}
pytorchReplicaSpecs:
Master:
Expand Down
35 changes: 24 additions & 11 deletions charts/tfjob/templates/tfjob.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,35 @@ metadata:
{{ $key }}: {{ $value | quote }}
{{- end }}
spec:
{{- if .Values.cleanPodPolicy }}
{{- if eq "None" $cleanPodPolicy }}
{{- if .Values.trainingOperatorCRD }}
runPolicy:
{{- if .Values.cleanPodPolicy }}
cleanPodPolicy: {{ .Values.cleanPodPolicy }}
{{- end }}
{{- if .Values.activeDeadlineSeconds }}
activeDeadlineSeconds: {{ .Values.activeDeadlineSeconds }}
{{- end }}
{{- if .Values.ttlSecondsAfterFinished }}
ttlSecondsAfterFinished: {{ .Values.ttlSecondsAfterFinished }}
{{- end }}
{{- else }}
{{- if .Values.cleanPodPolicy }}
{{- if eq "None" $cleanPodPolicy }}
cleanPodPolicy: None
{{- end }}
{{- if eq "Running" $cleanPodPolicy }}
{{- end }}
{{- if eq "Running" $cleanPodPolicy }}
cleanPodPolicy: Running
{{- end }}
{{- end }}
{{- if .Values.activeDeadlineSeconds }}
{{- end }}
{{- end }}
{{- if .Values.activeDeadlineSeconds }}
activeDeadlineSeconds: {{ .Values.activeDeadlineSeconds }}
{{- end }}
{{- if .Values.startingDeadlineSeconds }}
{{- end }}
{{- if .Values.startingDeadlineSeconds }}
startingDeadlineSeconds: {{ .Values.startingDeadlineSeconds }}
{{- end }}
{{- if .Values.ttlSecondsAfterFinished }}
{{- end }}
{{- if .Values.ttlSecondsAfterFinished }}
ttlSecondsAfterFinished: {{ .Values.ttlSecondsAfterFinished }}
{{- end }}
{{- end }}
tfReplicaSpecs:
{{- if .Values.ps }}
Expand Down
3 changes: 3 additions & 0 deletions pkg/apis/types/submit_pytorchjob.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,7 @@ type SubmitPyTorchJobArgs struct {

// Defines the TTL for cleaning up finished PytorchJobs. Defaults to infinite.
TTLSecondsAfterFinished int32 `yaml:"ttlSecondsAfterFinished,omitempty"`

// TrainingOperatorCRD reports whether the PyTorchJob CRD is compatible with the training-operator CRD.
TrainingOperatorCRD bool `yaml:"trainingOperatorCRD,omitempty"`
}
3 changes: 3 additions & 0 deletions pkg/apis/types/submit_tfjob.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ type SubmitTFJobArgs struct {

// TFRuntime stores the runtime
TFRuntime `yaml:"-"`

// TrainingOperatorCRD reports whether the TFJob CRD is compatible with the training-operator CRD.
TrainingOperatorCRD bool `yaml:"trainingOperatorCRD,omitempty"`
}

// SubmitTensorboardArgs is used to store tensorboard information
Expand Down
11 changes: 8 additions & 3 deletions pkg/apis/utils/training.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ func IsTensorFlowPod(name, ns string, pod *v1.Pod) bool {
return true
case pod.Labels[labelGroupNameV1alpha2] == "kubeflow.org":
return true
case pod.Labels[OperatorNameLabel] == "tfjob-controller":
return true
}
return false
}
Expand All @@ -49,10 +51,13 @@ func IsPyTorchPod(name, ns string, pod *v1.Pod) bool {
return false
}
// check the group name
if pod.Labels[labelPyTorchGroupName] != "kubeflow.org" {
return false
switch {
case pod.Labels[labelPyTorchGroupName] == "kubeflow.org":
return true
case pod.Labels[OperatorNameLabel] == "pytorchjob-controller":
return true
}
return true
return false
}

func IsMPIPod(name, ns string, pod *v1.Pod) bool {
Expand Down
3 changes: 3 additions & 0 deletions pkg/apis/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ const (

// deepspeedjob
deepspeedGroupName = "group-name"

// training-operator
OperatorNameLabel = "training.kubeflow.org/operator-name"
)

// GetTrainingJobTypes returns the supported training job types
Expand Down
5 changes: 5 additions & 0 deletions pkg/training/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ const (
requestGPUsOfJobAnnoKey = "requestGPUsOfJobOwner"

spotInstanceJobStatusAnnotation = "job-supervisor.kube-ai.io/job-status"

// TrainingReplicaTypeLabel training-operator replica type label
TrainingReplicaTypeLabel = "training.kubeflow.org/replica-type"
// TrainingReplicaIndexLabel training-operator replica index label
TrainingReplicaIndexLabel = "training.kubeflow.org/replica-index"
)

var (
Expand Down
5 changes: 5 additions & 0 deletions pkg/training/submit_pytorchjob.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"fmt"

"github.com/kubeflow/arena/pkg/apis/types"
"github.com/kubeflow/arena/pkg/k8saccesser"
"github.com/kubeflow/arena/pkg/util"
"github.com/kubeflow/arena/pkg/workflow"
log "github.com/sirupsen/logrus"
Expand All @@ -44,6 +45,10 @@ func SubmitPytorchJob(namespace string, submitArgs *types.SubmitPyTorchJobArgs)
}
// the master is also considered as a worker
submitArgs.WorkerCount = submitArgs.WorkerCount - 1

compatible := CompatibleJobCRD(k8saccesser.PytorchCRDName, "runPolicy")
submitArgs.TrainingOperatorCRD = compatible

pytorchjobChart := util.GetChartsFolder() + "/pytorchjob"
err = workflow.SubmitJob(submitArgs.Name, string(types.PytorchTrainingJob), namespace, submitArgs, pytorchjobChart, submitArgs.HelmOptions...)
if err != nil {
Expand Down
8 changes: 7 additions & 1 deletion pkg/training/submit_tfjob.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ package training
import (
"fmt"

log "github.com/sirupsen/logrus"

"github.com/kubeflow/arena/pkg/apis/types"
"github.com/kubeflow/arena/pkg/k8saccesser"
"github.com/kubeflow/arena/pkg/util"
"github.com/kubeflow/arena/pkg/workflow"
log "github.com/sirupsen/logrus"
)

func SubmitTFJob(namespace string, submitArgs *types.SubmitTFJobArgs) (err error) {
Expand Down Expand Up @@ -49,6 +51,10 @@ func SubmitTFJob(namespace string, submitArgs *types.SubmitTFJobArgs) (err error
if submitArgs.TFRuntime != nil {
tfjob_chart = util.GetChartsFolder() + "/" + submitArgs.TFRuntime.GetChartName()
}

compatible := CompatibleJobCRD(k8saccesser.TensorflowCRDName, "runPolicy")
submitArgs.TrainingOperatorCRD = compatible

err = workflow.SubmitJob(submitArgs.Name, string(types.TFTrainingJob), namespace, submitArgs, tfjob_chart, submitArgs.HelmOptions...)
if err != nil {
return err
Expand Down
28 changes: 26 additions & 2 deletions pkg/training/trainer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@
package training

import (
"context"
"fmt"
"sort"
"sync"

log "github.com/sirupsen/logrus"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/kubeflow/arena/pkg/apis/config"
"github.com/kubeflow/arena/pkg/apis/types"
"github.com/kubeflow/arena/pkg/util/kubectl"
log "github.com/sirupsen/logrus"
v1 "k8s.io/api/core/v1"
)

var trainers map[types.TrainingJobType]Trainer
Expand Down Expand Up @@ -197,3 +200,24 @@ func CheckJobIsOwnedByTrainer(labels map[string]string) error {
}
return types.ErrNoPrivilegesToOperateJob
}

// CompatibleJobCRD reports whether the CRD named crdName is compatible with
// the training-operator CRD layout: it returns true when at least one version
// of the CRD declares fieldToCheck as a property of "spec" in its OpenAPI v3
// schema.
//
// It returns false when the CRD cannot be fetched (the error is logged) or
// when no version declares the field.
func CompatibleJobCRD(crdName, fieldToCheck string) bool {
	arenaConfiger := config.GetArenaConfiger()

	crd, err := arenaConfiger.GetAPIExtensionClientSet().ApiextensionsV1().CustomResourceDefinitions().Get(context.TODO(), crdName, metav1.GetOptions{})
	if err != nil {
		// The helper is shared by several job types, so name the CRD that failed.
		log.Errorf("Get crd %s failed, error: %s", crdName, err)
		return false
	}

	for _, version := range crd.Spec.Versions {
		// Schema and OpenAPIV3Schema are optional pointers; a version without
		// a schema cannot declare the field and must not be dereferenced.
		if version.Schema == nil || version.Schema.OpenAPIV3Schema == nil {
			continue
		}
		// Indexing a missing "spec" entry yields a zero value whose nil
		// Properties map is safe to read.
		if _, ok := version.Schema.OpenAPIV3Schema.Properties["spec"].Properties[fieldToCheck]; ok {
			return true
		}
	}

	return false
}
14 changes: 10 additions & 4 deletions pkg/training/trainer_pytorch.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,11 +332,17 @@ func (tt *PyTorchJobTrainer) GetTrainingJob(name, namespace string) (TrainingJob
}

// isChiefPod reports whether item is the chief ("master") pod of a PyTorchJob,
// judged from the pod's replica-type labels.
// NOTE(review): this span is a rendered diff hunk without +/- markers, so lines
// of the previous and the new implementation appear interleaved; the comments
// below follow the lines that use the isChiefPod flag — confirm against the
// committed file.
func (tt *PyTorchJobTrainer) isChiefPod(pytorchjob *pytorchv1.PyTorchJob, item *v1.Pod) bool {
if item.Labels[pytorchReplicaTypeLabel] != "master" {
return false
isChiefPod := false

// Match on the pytorchReplicaTypeLabel label.
if val, ok := item.Labels[pytorchReplicaTypeLabel]; ok && val == "master" {
isChiefPod = true
}
log.Debugf("the pytorchjob %s with labels master", item.Name)
return true

// Match on the training-operator replica-type label
// (training.kubeflow.org/replica-type).
if val, ok := item.Labels[TrainingReplicaTypeLabel]; ok && val == "master" {
isChiefPod = true
}

return isChiefPod
}

// check Labels: release==pytorchjob.name/app=="pytorchjob", namespace
Expand Down
24 changes: 14 additions & 10 deletions pkg/training/trainer_tensorflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,31 +330,35 @@ func (tt *TensorFlowJobTrainer) GetTrainingJob(name, namespace string) (Training
}

// isChiefPod reports whether item is the chief pod of tfjob.
//
// When the job spec declares a Chief replica, the pod is chief if one of its
// replica-type labels equals "chief". Otherwise the pod is chief if it is
// labelled as a "worker" with replica index "0".
// NOTE(review): this span is a rendered diff hunk without +/- markers, so lines
// of the previous and the new implementation appear interleaved; the
// description follows the lines that use the isChiefPod flag — confirm against
// the committed file.
func (tt *TensorFlowJobTrainer) isChiefPod(tfjob *tfv1.TFJob, item *v1.Pod) bool {
isChiefPod := false

// find chief pod in chief mode
if _, ok := tfjob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeChief]; ok {
log.Debugf("The distributed tensorflow is in chief mode")
if val, ok := item.Labels[tfReplicaTypeLabel]; ok && (val == "chief") {
log.Debugf("the tfjob %s with labels %s is the chief pod", item.Name, val)
return true
} else {
return false
}
// Same check against the training-operator replica-type label
// (training.kubeflow.org/replica-type).
if val, ok := item.Labels[TrainingReplicaTypeLabel]; ok && (val == "chief") {
log.Debugf("the tfjob %s with labels %s is the chief pod", item.Name, val)
return true
}
return false
}

// No Chief replica declared: look for the worker with index "0" using the
// tfReplicaTypeLabel / tfReplicaIndexLabel pair.
if val, ok := item.Labels[tfReplicaTypeLabel]; ok && (val == "worker") {
log.Debugf("the tfjob %s with labels %s is the chief pod", item.Name, val)
} else {
return false
if val, ok := item.Labels[tfReplicaIndexLabel]; ok && (val == "0") {
isChiefPod = true
}
}

if val, ok := item.Labels[tfReplicaIndexLabel]; ok && (val == "0") {
log.Debugf("the chief pod of tfjob %s with labels %s is found.", item.Name, val)
} else {
return false
// Same worker/index-0 check against the training-operator labels
// (training.kubeflow.org/replica-type and .../replica-index).
if val, ok := item.Labels[TrainingReplicaTypeLabel]; ok && (val == "worker") {
if val, ok := item.Labels[TrainingReplicaIndexLabel]; ok && (val == "0") {
isChiefPod = true
}
}

return true
return isChiefPod
}

func (tt *TensorFlowJobTrainer) isTensorFlowJob(name, ns string, item *tfv1.TFJob) bool {
Expand Down

0 comments on commit cdf1bb3

Please sign in to comment.