Skip to content

Commit

Permalink
feat: add backend param for triton serving (#1039)
Browse files Browse the repository at this point in the history
  • Loading branch information
gujingit committed Feb 18, 2024
1 parent ed2aea2 commit f27a678
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 16 deletions.
17 changes: 17 additions & 0 deletions charts/triton/templates/_helpers.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,20 @@ Create chart name and version as used by the chart label.
{{- define "nvidia-triton-server.chart" -}}
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" -}}
{{- end -}}

{{/*
Return the tritonserver image reference.

Selection order:
  1. .Values.image, when it is set to a non-empty value, is returned as-is.
  2. Otherwise the image is picked from .Values.backend:
       "vllm"    -> nvcr.io/nvidia/tritonserver:24.01-vllm-python-py3
       "trt-llm" -> nvcr.io/nvidia/tritonserver:24.01-trtllm-python-py3
       anything else (including unset) -> nvcr.io/nvidia/tritonserver:24.01-py3

The result is emitted without surrounding whitespace or quotes.
*/}}
{{- define "triton.image" -}}
{{- /* Normalise an unset backend to "" so that eq never receives a nil operand;
       some template engine versions reject nil-vs-string comparisons. */ -}}
{{- $backend := default "" .Values.backend -}}
{{- if .Values.image }}
{{- .Values.image -}}
{{- else }}
{{- if eq $backend "vllm" }}
{{- "nvcr.io/nvidia/tritonserver:24.01-vllm-python-py3" -}}
{{- else if eq $backend "trt-llm" }}
{{- "nvcr.io/nvidia/tritonserver:24.01-trtllm-python-py3" -}}
{{- else }}
{{- "nvcr.io/nvidia/tritonserver:24.01-py3" -}}
{{- end }}
{{- end }}
{{- end -}}
4 changes: 1 addition & 3 deletions charts/triton/templates/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ spec:
{{- end }}
containers:
- name: tritonserver
{{- if .Values.image }}
image: "{{ .Values.image }}"
{{- end }}
image: {{ include "triton.image" . }}
{{- if .Values.imagePullPolicy }}
imagePullPolicy: "{{ .Values.imagePullPolicy }}"
{{- end }}
Expand Down
2 changes: 0 additions & 2 deletions charts/triton/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ serviceType: ClusterIP
servingName:
servingVersion:

image: "nvcr.io/nvidia/tritonserver:24.01-py3"

imagePullPolicy: "IfNotPresent"

cpu: 1.0
Expand Down
1 change: 0 additions & 1 deletion pkg/apis/serving/triton_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ func NewTritonServingJobBuilder() *TritonServingJobBuilder {
GrpcPort: 8001,
MetricsPort: 8002,
CommonServingArgs: types.CommonServingArgs{
Image: argsbuilder.DefaultTritonServingImage,
ImagePullPolicy: "IfNotPresent",
Replicas: 1,
Namespace: "default",
Expand Down
1 change: 1 addition & 0 deletions pkg/apis/types/serving.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ type SeldonServingArgs struct {
}

type TritonServingArgs struct {
Backend string `yaml:"backend"` // --backend
ModelRepository string `yaml:"modelRepository"` // --model-repository
MetricsPort int `yaml:"metricsPort"` // --metrics-port
HttpPort int `yaml:"httpPort"` // --http-port
Expand Down
16 changes: 6 additions & 10 deletions pkg/argsbuilder/serving_triton.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ import (
"github.com/spf13/cobra"
)

const (
DefaultTritonServingImage = "nvcr.io/nvidia/tritonserver:24.01-py3"
)

type TritonServingArgsBuilder struct {
args *types.TritonServingArgs
argValues map[string]interface{}
Expand All @@ -43,7 +39,6 @@ func NewTritonServingArgsBuilder(args *types.TritonServingArgs) ArgsBuilder {
s.AddSubBuilder(
NewServingArgsBuilder(&s.args.CommonServingArgs),
)
s.AddArgValue("default-image", DefaultTritonServingImage)
return s
}

Expand Down Expand Up @@ -72,6 +67,7 @@ func (s *TritonServingArgsBuilder) AddCommandFlags(command *cobra.Command) {
s.subBuilders[name].AddCommandFlags(command)
}
var loadModels []string
command.Flags().StringVar(&s.args.Backend, "backend", "", "the backend type of triton server. Valid values: [vllm|trt-llm]")
command.Flags().StringVar(&s.args.ModelRepository, "model-repository", "", "the path of triton model path")
command.Flags().IntVar(&s.args.HttpPort, "http-port", 8000, "the port of http serving server")
command.Flags().IntVar(&s.args.GrpcPort, "grpc-port", 8001, "the port of grpc serving server")
Expand Down Expand Up @@ -112,14 +108,14 @@ func (s *TritonServingArgsBuilder) Build() error {
}

func (s *TritonServingArgsBuilder) validate() (err error) {
if s.args.Image == "" {
return fmt.Errorf("image must be specified")
}
/*
if s.args.Backend != "" {
if s.args.Backend != "vllm" && s.args.Backend != "trt-llm" {
return fmt.Errorf("backend %s is Invalid. Triton backend only supports vllm or trt-llm", s.args.Backend)
}
if s.args.GPUCount == 0 {
return fmt.Errorf("--gpus must be specific at least 1 GPU")
}
*/
}
return nil
}

Expand Down

0 comments on commit f27a678

Please sign in to comment.