Skip to content

Commit

Permalink
support update --data in kserve serving job (#1049)
Browse files Browse the repository at this point in the history
Signed-off-by: zibai <[email protected]>
  • Loading branch information
gujingit committed Mar 18, 2024
1 parent b7f0ecf commit 8b05634
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 1 deletion.
1 change: 1 addition & 0 deletions pkg/apis/types/update_serving.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type CommonUpdateServingArgs struct {
Tolerations []TolerationArgs `yaml:"tolerations"` // --toleration
Shell string `yaml:"shell"` // --shell
Command string `yaml:"command"` // --command
ModelDirs map[string]string `yaml:"modelDirs"` // --data
}

type UpdateTensorFlowServingArgs struct {
Expand Down
33 changes: 32 additions & 1 deletion pkg/argsbuilder/update_serving.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package argsbuilder
import (
"fmt"
"github.com/kubeflow/arena/pkg/apis/types"
"github.com/kubeflow/arena/pkg/util"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"k8s.io/apimachinery/pkg/api/resource"
Expand Down Expand Up @@ -68,6 +69,7 @@ func (s *UpdateServingArgsBuilder) AddCommandFlags(command *cobra.Command) {
envs []string
selectors []string
tolerations []string
dataset []string
)

command.Flags().StringVar(&s.args.Name, "name", "", "the serving name")
Expand All @@ -85,11 +87,14 @@ func (s *UpdateServingArgsBuilder) AddCommandFlags(command *cobra.Command) {
command.Flags().StringVar(&s.args.Command, "command", "", "the command will inject to container's command.")
command.Flags().StringArrayVarP(&selectors, "selector", "", []string{}, `assigning jobs to some k8s particular nodes, usage: "--selector=key=value" or "--selector key=value" `)
command.Flags().StringArrayVarP(&tolerations, "toleration", "", []string{}, `tolerate some k8s nodes with taints,usage: "--toleration key=value:effect,operator" or "--toleration all" `)
command.Flags().StringArrayVarP(&dataset, "data", "d", []string{}, "specify the trained models datasource to mount for serving, like <name_of_datasource>:<mount_point_on_job>")

s.AddArgValue("env", &envs).
AddArgValue("annotation", &annotations).
AddArgValue("selector", &selectors).
AddArgValue("label", &labels).
AddArgValue("toleration", &tolerations)
AddArgValue("toleration", &tolerations).
AddArgValue("data", &dataset)
}

func (s *UpdateServingArgsBuilder) PreBuild() error {
Expand Down Expand Up @@ -130,6 +135,10 @@ func (s *UpdateServingArgsBuilder) PreBuild() error {
return err
}

if err := s.setDataSet(); err != nil {
return err
}

if err := s.check(); err != nil {
return err
}
Expand Down Expand Up @@ -260,6 +269,28 @@ func (s *UpdateServingArgsBuilder) setLabels() error {
return nil
}

// setDataSets is used to handle option --data
func (s *UpdateServingArgsBuilder) setDataSet() error {
s.args.ModelDirs = map[string]string{}
argKey := "data"
var dataSet *[]string
value, ok := s.argValues[argKey]
if !ok {
return nil
}
dataSet = value.(*[]string)
log.Debugf("dataset: %v", *dataSet)
if len(*dataSet) <= 0 {
return nil
}
err := util.ValidateDatasets(*dataSet)
if err != nil {
return err
}
s.args.ModelDirs = transformSliceToMap(*dataSet, ":")
return nil
}

func (s *UpdateServingArgsBuilder) checkNamespace() error {
if s.args.Namespace == "" {
return fmt.Errorf("namespace not set, please set it")
Expand Down
24 changes: 24 additions & 0 deletions pkg/serving/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,30 @@ func setInferenceServiceForCustomModel(args *types.UpdateKServeArgs, inferenceSe
inferenceService.Spec.Predictor.Containers[0].Image = args.Image
}

//set volume
if len(args.ModelDirs) != 0 {
log.Debugf("update modelDirs: [%+v]", args.ModelDirs)
var volumes []v1.Volume
var volumeMounts []v1.VolumeMount

for pvName, mountPath := range args.ModelDirs {
volumes = append(volumes, v1.Volume{
Name: pvName,
VolumeSource: v1.VolumeSource{
PersistentVolumeClaim: &v1.PersistentVolumeClaimVolumeSource{
ClaimName: pvName,
},
},
})
volumeMounts = append(volumeMounts, v1.VolumeMount{
Name: pvName,
MountPath: mountPath,
})
}
inferenceService.Spec.Predictor.Containers[0].VolumeMounts = volumeMounts
inferenceService.Spec.Predictor.Volumes = volumes
}

// set resources requests
resourceRequests := inferenceService.Spec.Predictor.Containers[0].Resources.Requests
if resourceRequests == nil {
Expand Down

0 comments on commit 8b05634

Please sign in to comment.