diff --git a/charts/triton/templates/_helpers.tpl b/charts/triton/templates/_helpers.tpl index e8b566053..326e3baa3 100755 --- a/charts/triton/templates/_helpers.tpl +++ b/charts/triton/templates/_helpers.tpl @@ -30,3 +30,20 @@ Create chart name and version as used by the chart label. {{- define "nvidia-triton-server.chart" -}} {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" -}} {{- end -}} + +{{/* +Return the tritonserver image: .Values.image when it is set, otherwise a default tag chosen by .Values.backend ("vllm", "trt-llm", or the generic py3 image for any other value, including an unset backend). +*/}} +{{- define "triton.image" -}} +{{- if .Values.image }} +{{- .Values.image -}} +{{- else }} +{{- if eq (default "" .Values.backend) "vllm" }} +{{- "nvcr.io/nvidia/tritonserver:24.01-vllm-python-py3" -}} +{{- else if eq (default "" .Values.backend) "trt-llm" }} +{{- "nvcr.io/nvidia/tritonserver:24.01-trtllm-python-py3" -}} +{{- else }} +{{- "nvcr.io/nvidia/tritonserver:24.01-py3" -}} +{{- end }} +{{- end }} +{{- end -}} diff --git a/charts/triton/templates/deployment.yaml b/charts/triton/templates/deployment.yaml index 392b6b3ba..28e066fd3 100755 --- a/charts/triton/templates/deployment.yaml +++ b/charts/triton/templates/deployment.yaml @@ -85,9 +85,7 @@ spec: {{- end }} containers: - name: tritonserver - {{- if .Values.image }} - image: "{{ .Values.image }}" - {{- end }} + image: {{ include "triton.image" . | quote 
}} {{- if .Values.imagePullPolicy }} imagePullPolicy: "{{ .Values.imagePullPolicy }}" {{- end }} diff --git a/charts/triton/values.yaml b/charts/triton/values.yaml index d0745f932..9bea77247 100755 --- a/charts/triton/values.yaml +++ b/charts/triton/values.yaml @@ -11,8 +11,6 @@ serviceType: ClusterIP servingName: servingVersion: -image: "nvcr.io/nvidia/tritonserver:24.01-py3" - imagePullPolicy: "IfNotPresent" cpu: 1.0 diff --git a/pkg/apis/serving/triton_builder.go b/pkg/apis/serving/triton_builder.go index afeb9d555..8e54f4145 100644 --- a/pkg/apis/serving/triton_builder.go +++ b/pkg/apis/serving/triton_builder.go @@ -20,7 +20,6 @@ func NewTritonServingJobBuilder() *TritonServingJobBuilder { GrpcPort: 8001, MetricsPort: 8002, CommonServingArgs: types.CommonServingArgs{ - Image: argsbuilder.DefaultTritonServingImage, ImagePullPolicy: "IfNotPresent", Replicas: 1, Namespace: "default", diff --git a/pkg/apis/types/serving.go b/pkg/apis/types/serving.go index 0e2ffc735..9b9fc652e 100644 --- a/pkg/apis/types/serving.go +++ b/pkg/apis/types/serving.go @@ -253,6 +253,7 @@ type SeldonServingArgs struct { } type TritonServingArgs struct { + Backend string `yaml:"backend"` // --backend ModelRepository string `yaml:"modelRepository"` // --model-repository MetricsPort int `yaml:"metricsPort"` // --metrics-port HttpPort int `yaml:"httpPort"` // --http-port diff --git a/pkg/argsbuilder/serving_triton.go b/pkg/argsbuilder/serving_triton.go index d8b388709..4ca1233d3 100644 --- a/pkg/argsbuilder/serving_triton.go +++ b/pkg/argsbuilder/serving_triton.go @@ -23,10 +23,6 @@ import ( "github.com/spf13/cobra" ) -const ( - DefaultTritonServingImage = "nvcr.io/nvidia/tritonserver:24.01-py3" -) - type TritonServingArgsBuilder struct { args *types.TritonServingArgs argValues map[string]interface{} @@ -43,7 +39,6 @@ func NewTritonServingArgsBuilder(args *types.TritonServingArgs) ArgsBuilder { s.AddSubBuilder( NewServingArgsBuilder(&s.args.CommonServingArgs), ) - 
s.AddArgValue("default-image", DefaultTritonServingImage) return s } @@ -72,6 +67,7 @@ func (s *TritonServingArgsBuilder) AddCommandFlags(command *cobra.Command) { s.subBuilders[name].AddCommandFlags(command) } var loadModels []string + command.Flags().StringVar(&s.args.Backend, "backend", "", "the backend type of triton server. Valid values: [vllm|trt-llm]") command.Flags().StringVar(&s.args.ModelRepository, "model-repository", "", "the path of triton model path") command.Flags().IntVar(&s.args.HttpPort, "http-port", 8000, "the port of http serving server") command.Flags().IntVar(&s.args.GrpcPort, "grpc-port", 8001, "the port of grpc serving server") @@ -112,14 +108,14 @@ func (s *TritonServingArgsBuilder) Build() error { } func (s *TritonServingArgsBuilder) validate() (err error) { - if s.args.Image == "" { - return fmt.Errorf("image must be specified") - } - /* + if s.args.Backend != "" { // backend is optional; when set it must be a supported value and at least one GPU is required + if s.args.Backend != "vllm" && s.args.Backend != "trt-llm" { + return fmt.Errorf("invalid backend %q: triton backend only supports vllm or trt-llm", s.args.Backend) + } if s.args.GPUCount == 0 { return fmt.Errorf("--gpus must be specific at least 1 GPU") } - */ + } return nil }