Skip to content

Commit

Permalink
feat(transport): ttheader streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
felix021 committed Jun 6, 2024
1 parent 188eeba commit 83916e7
Show file tree
Hide file tree
Showing 78 changed files with 8,833 additions and 386 deletions.
8 changes: 2 additions & 6 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,7 @@ func newMockCliTransHandlerFactory(ctrl *gomock.Controller) remote.ClientTransHa
handler.EXPECT().SetPipeline(gomock.Any()).AnyTimes()
handler.EXPECT().OnError(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
handler.EXPECT().OnInactive(gomock.Any(), gomock.Any()).AnyTimes()
factory := mocksremote.NewMockClientTransHandlerFactory(ctrl)
factory.EXPECT().NewTransHandler(gomock.Any()).DoAndReturn(func(opt *remote.ClientOption) (remote.ClientTransHandler, error) {
return handler, nil
}).AnyTimes()
return factory
return mocks.NewMockCliTransHandlerFactory(handler)
}

func newMockClient(tb testing.TB, ctrl *gomock.Controller, extra ...Option) Client {
Expand All @@ -131,7 +127,7 @@ func newMockClient(tb testing.TB, ctrl *gomock.Controller, extra ...Option) Clie

svcInfo := mocks.ServiceInfo()
cli, err := NewClient(svcInfo, opts...)
test.Assert(tb, err == nil)
test.Assert(tb, err == nil, err)

return cli
}
Expand Down
8 changes: 8 additions & 0 deletions client/option_advanced.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,11 @@ func WithGRPCTLSConfig(tlsConfig *tls.Config) Option {
o.GRPCConnectOpts.TLSConfig = grpc.TLSConfig(tlsConfig)
}}
}

// WithTTHeaderFrameMetaHandler sets the FrameMetaHandler for TTHeader Streaming
func WithTTHeaderFrameMetaHandler(h remote.FrameMetaHandler) Option {
return Option{F: func(o *client.Options, di *utils.Slice) {
di.Push(fmt.Sprintf("WithTTHeaderFrameMetaHandler(%T)", h))
o.RemoteOpt.TTHeaderFrameMetaHandler = h
}}
}
29 changes: 22 additions & 7 deletions client/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ package client

import (
"context"
"fmt"
"io"
"sync/atomic"

"github.com/cloudwego/kitex/pkg/kerrors"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata"
"github.com/cloudwego/kitex/pkg/serviceinfo"

"github.com/cloudwego/kitex/pkg/endpoint"
"github.com/cloudwego/kitex/pkg/kerrors"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/remotecli"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
"github.com/cloudwego/kitex/pkg/streaming"
)

Expand Down Expand Up @@ -71,10 +71,16 @@ func (kc *kClient) invokeRecvEndpoint() endpoint.RecvEndpoint {
}

func (kc *kClient) invokeStreamingEndpoint() (endpoint.Endpoint, error) {
handler, err := kc.opt.RemoteOpt.CliHandlerFactory.NewTransHandler(kc.opt.RemoteOpt)
handler, err := kc.newStreamClientTransHandler()
if err != nil {
return nil, err
// if a ClientTransHandler does not support streaming, the error will be returned when Kitex is used to
// send streaming requests. This saves the pain for each ClientTransHandler to implement the interface,
// with no impact on existing ClientTransHandler implementations.
return func(ctx context.Context, req, resp interface{}) (err error) {
panic(err)
}, nil
}

for _, h := range kc.opt.MetaHandlers {
if shdlr, ok := h.(remote.StreamingMetaHandler); ok {
kc.opt.RemoteOpt.StreamingMetaHandlers = append(kc.opt.RemoteOpt.StreamingMetaHandlers, shdlr)
Expand All @@ -97,6 +103,15 @@ func (kc *kClient) invokeStreamingEndpoint() (endpoint.Endpoint, error) {
}, nil
}

func (kc *kClient) newStreamClientTransHandler() (remote.ClientTransHandler, error) {
handlerFactory, ok := kc.opt.RemoteOpt.CliHandlerFactory.(remote.ClientStreamTransHandlerFactory)
if !ok {
return nil, fmt.Errorf("remote.ClientStreamTransHandlerFactory is not implement by %T",
kc.opt.RemoteOpt.CliHandlerFactory)
}
return handlerFactory.NewStreamTransHandler(kc.opt.RemoteOpt)
}

func (kc *kClient) getStreamingMode(ri rpcinfo.RPCInfo) serviceinfo.StreamingMode {
methodInfo := kc.svcInfo.MethodInfo(ri.Invocation().MethodName())
if methodInfo == nil {
Expand Down Expand Up @@ -204,7 +219,7 @@ func (s *stream) DoFinish(err error) {
err = nil
}
if s.scm != nil {
s.scm.ReleaseConn(err, s.ri)
s.scm.ReleaseConn(s, err, s.ri)
}
s.kc.opt.TracerCtl.DoFinish(s.Context(), s.ri, err)
}
Expand Down
70 changes: 61 additions & 9 deletions client/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/cloudwego/kitex/pkg/kerrors"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/remotecli"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
Expand Down Expand Up @@ -72,6 +73,11 @@ func TestStream(t *testing.T) {
test.Assert(t, err == nil, err)
}

func init() {
mocks.DefaultNewStreamFunc = nphttp2.NewStream
mocks.DefaultNewStreamTransHandlerFunc = nphttp2.NewCliTransHandlerFactory().NewTransHandler
}

func TestStreaming(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
Expand Down Expand Up @@ -102,7 +108,8 @@ func TestStreaming(t *testing.T) {
connpool := mock_remote.NewMockConnPool(ctrl)
connpool.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(conn, nil)
cliInfo.ConnPool = connpool
s, cr, _ := remotecli.NewStream(ctx, mockRPCInfo, new(mocks.MockCliTransHandler), cliInfo)
handler := new(mocks.MockCliTransHandler)
s, cr, _ := remotecli.NewStream(ctx, mockRPCInfo, handler, cliInfo)
stream := newStream(
s, cr, kc, mockRPCInfo, serviceinfo.StreamingBidirectional,
func(stream streaming.Stream, message interface{}) (err error) {
Expand Down Expand Up @@ -177,10 +184,11 @@ func TestClosedClient(t *testing.T) {

type mockStream struct {
streaming.Stream
ctx context.Context
close func() error
header func() (metadata.MD, error)
recv func(msg interface{}) error
ctx context.Context
trailer func() metadata.MD
close func() error
header func() (metadata.MD, error)
recv func(msg interface{}) error
}

func (s *mockStream) Context() context.Context {
Expand All @@ -195,6 +203,10 @@ func (s *mockStream) RecvMsg(msg interface{}) error {
return s.recv(msg)
}

func (s *mockStream) Trailer() metadata.MD {
return s.trailer()
}

func (s *mockStream) Close() error {
return s.close()
}
Expand Down Expand Up @@ -322,7 +334,8 @@ func Test_stream_RecvMsg(t *testing.T) {
})

t.Run("no-error-client-streaming", func(t *testing.T) {
ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), nil, rpcinfo.NewRPCStats())
cfg := rpcinfo.NewRPCConfig()
ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), cfg, rpcinfo.NewRPCStats())
st := &mockStream{
ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri),
}
Expand Down Expand Up @@ -470,7 +483,7 @@ func Test_stream_DoFinish(t *testing.T) {
defer ctrl.Finish()

t.Run("no-error", func(t *testing.T) {
ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, rpcinfo.NewRPCStats())
ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats())
st := &mockStream{
ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri),
}
Expand Down Expand Up @@ -500,7 +513,7 @@ func Test_stream_DoFinish(t *testing.T) {
})

t.Run("EOF", func(t *testing.T) {
ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, rpcinfo.NewRPCStats())
ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats())
st := &mockStream{
ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri),
}
Expand Down Expand Up @@ -530,7 +543,7 @@ func Test_stream_DoFinish(t *testing.T) {
})

t.Run("biz-status-error", func(t *testing.T) {
ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, rpcinfo.NewRPCStats())
ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats())
st := &mockStream{
ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri),
}
Expand Down Expand Up @@ -637,3 +650,42 @@ func Test_isRPCError(t *testing.T) {
test.Assert(t, isRPCError(errors.New("error")))
})
}

type mockHandlerFactory struct {
remote.ClientTransHandlerFactory
}

var _ remote.ClientStreamTransHandlerFactory = (*mockHandlerFactoryWithNewStream)(nil)

type mockHandlerFactoryWithNewStream struct {
remote.ClientTransHandlerFactory
}

func (m mockHandlerFactoryWithNewStream) NewStreamTransHandler(opt *remote.ClientOption) (remote.ClientTransHandler, error) {
return nil, nil
}

func Test_kClient_newStreamClientTransHandler(t *testing.T) {
t.Run("not-implemented-NewStream", func(t *testing.T) {
kc := &kClient{
opt: &client.Options{
RemoteOpt: &remote.ClientOption{
CliHandlerFactory: &mockHandlerFactory{},
},
},
}
_, err := kc.newStreamClientTransHandler()
test.Assert(t, err != nil, err) // not implemented
})
t.Run("implemented-NewStream", func(t *testing.T) {
kc := &kClient{
opt: &client.Options{
RemoteOpt: &remote.ClientOption{
CliHandlerFactory: &mockHandlerFactoryWithNewStream{},
},
},
}
_, err := kc.newStreamClientTransHandler()
test.Assert(t, err == nil)
})
}
44 changes: 44 additions & 0 deletions client/streamclient/ttheader_option.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package streamclient

import (
"fmt"

"github.com/cloudwego/kitex/client"
"github.com/cloudwego/kitex/pkg/utils"
)

// WithTTHeaderStreamingWaitMetaFrame instructs Kitex client to wait for meta frame
func WithTTHeaderStreamingWaitMetaFrame(wait bool) Option {
return Option{
F: func(o *client.Options, di *utils.Slice) {
di.Push(fmt.Sprintf("WithTTHeaderStreamingWaitMetaFrame(%v)", wait))
o.RemoteOpt.TTHeaderStreamingWaitMetaFrame = wait
},
}
}

// WithTTHeaderStreamingGRPCCompatible instructs Kitex client to send grpc style metadata to the server
func WithTTHeaderStreamingGRPCCompatible() Option {
return Option{
F: func(o *client.Options, di *utils.Slice) {
di.Push("WithTTHeaderStreamingGRPCCompatible()")
o.RemoteOpt.TTHeaderStreamingGRPCCompatible = true
},
}
}
47 changes: 42 additions & 5 deletions internal/mocks/transhandlerclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,35 +20,72 @@ import (
"context"
"net"

mocksremote "github.com/cloudwego/kitex/internal/mocks/remote"
"github.com/cloudwego/kitex/pkg/kerrors"
"github.com/cloudwego/kitex/pkg/klog"
"github.com/cloudwego/kitex/pkg/remote"
kitex "github.com/cloudwego/kitex/pkg/serviceinfo"
"github.com/cloudwego/kitex/pkg/streaming"
)

var (
_ remote.ClientStreamAllocator = (*MockCliTransHandler)(nil)
_ remote.ClientTransHandlerFactory = (*mockCliTransHandlerFactory)(nil)
_ remote.ClientStreamTransHandlerFactory = (*mockCliTransHandlerFactory)(nil)
DefaultNewStreamFunc func(
ctx context.Context, svcInfo *kitex.ServiceInfo, conn net.Conn, handler remote.TransReadWriter,
) streaming.Stream = nil

DefaultNewStreamTransHandlerFunc func(opt *remote.ClientOption) (remote.ClientTransHandler, error) = nil
)

type mockCliTransHandlerFactory struct {
hdlr *MockCliTransHandler
hdlr *mocksremote.MockClientTransHandler
NewStreamTransHandlerFunc func(opt *remote.ClientOption) (remote.ClientTransHandler, error)
}

func (f *mockCliTransHandlerFactory) NewStreamTransHandler(opt *remote.ClientOption) (remote.ClientTransHandler, error) {
if f.NewStreamTransHandlerFunc != nil {
return f.NewStreamTransHandlerFunc(opt)
}
if DefaultNewStreamTransHandlerFunc != nil {
return DefaultNewStreamTransHandlerFunc(opt)
}
panic("NewStreamTransHandlerFunc is not available in mocks.mockCliTransHandlerFactory")
}

// NewMockCliTransHandlerFactory .
func NewMockCliTransHandlerFactory(hdrl *MockCliTransHandler) remote.ClientTransHandlerFactory {
return &mockCliTransHandlerFactory{hdrl}
func NewMockCliTransHandlerFactory(handler *mocksremote.MockClientTransHandler) remote.ClientTransHandlerFactory {
return &mockCliTransHandlerFactory{
hdlr: handler,
}
}

func (f *mockCliTransHandlerFactory) NewTransHandler(opt *remote.ClientOption) (remote.ClientTransHandler, error) {
f.hdlr.opt = opt
return f.hdlr, nil
}

// MockCliTransHandler .
type MockCliTransHandler struct {
opt *remote.ClientOption
transPipe *remote.TransPipeline

WriteFunc func(ctx context.Context, conn net.Conn, send remote.Message) (nctx context.Context, err error)

ReadFunc func(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error)

OnMessageFunc func(ctx context.Context, args, result remote.Message) (context.Context, error)

NewStreamFunc func(ctx context.Context, svcInfo *kitex.ServiceInfo, conn net.Conn, handler remote.TransReadWriter) (streaming.Stream, error)
}

func (t *MockCliTransHandler) NewStream(ctx context.Context, svcInfo *kitex.ServiceInfo, conn net.Conn, handler remote.TransReadWriter) (streaming.Stream, error) {
if t.NewStreamFunc != nil {
return t.NewStreamFunc(ctx, svcInfo, conn, handler)
}
if DefaultNewStreamFunc != nil {
return DefaultNewStreamFunc(ctx, svcInfo, conn, handler), nil
}
panic("NewStreamFunc is not available in mocks.MockCliTransHandler")
}

// Write implements the remote.TransHandler interface.
Expand Down
10 changes: 8 additions & 2 deletions internal/server/remote_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,14 @@ import (

func newServerRemoteOption() *remote.ServerOption {
return &remote.ServerOption{
TransServerFactory: netpoll.NewTransServerFactory(),
SvrHandlerFactory: detection.NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()),
TransServerFactory: netpoll.NewTransServerFactory(),
SvrHandlerFactory: detection.NewSvrTransHandlerFactory(
// default transHandler
netpoll.NewSvrTransHandlerFactory(),
// detectable transHandlers
netpoll.NewTTHeaderStreamingSvrTransHandlerFactory(),
nphttp2.NewSvrTransHandlerFactory(),
),
Codec: codec.NewDefaultCodec(),
Address: defaultAddress,
ExitWaitTime: defaultExitWaitTime,
Expand Down
10 changes: 8 additions & 2 deletions internal/server/remote_option_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,14 @@ import (

func newServerRemoteOption() *remote.ServerOption {
return &remote.ServerOption{
TransServerFactory: gonet.NewTransServerFactory(),
SvrHandlerFactory: detection.NewSvrTransHandlerFactory(gonet.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()),
TransServerFactory: gonet.NewTransServerFactory(),
SvrHandlerFactory: detection.NewSvrTransHandlerFactory(
// default transHandler
gonet.NewSvrTransHandlerFactory(),
// detectable transHandlers
gonet.NewTTHeaderStreamingSvrTransHandlerFactory(),
nphttp2.NewSvrTransHandlerFactory(),
),
Codec: codec.NewDefaultCodec(),
Address: defaultAddress,
ExitWaitTime: defaultExitWaitTime,
Expand Down
Loading

0 comments on commit 83916e7

Please sign in to comment.