diff --git a/client/client_test.go b/client/client_test.go index 64c0e3cfc8..f15af2487a 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -113,11 +113,7 @@ func newMockCliTransHandlerFactory(ctrl *gomock.Controller) remote.ClientTransHa handler.EXPECT().SetPipeline(gomock.Any()).AnyTimes() handler.EXPECT().OnError(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() handler.EXPECT().OnInactive(gomock.Any(), gomock.Any()).AnyTimes() - factory := mocksremote.NewMockClientTransHandlerFactory(ctrl) - factory.EXPECT().NewTransHandler(gomock.Any()).DoAndReturn(func(opt *remote.ClientOption) (remote.ClientTransHandler, error) { - return handler, nil - }).AnyTimes() - return factory + return mocks.NewMockCliTransHandlerFactory(handler) } func newMockClient(tb testing.TB, ctrl *gomock.Controller, extra ...Option) Client { @@ -131,7 +127,7 @@ func newMockClient(tb testing.TB, ctrl *gomock.Controller, extra ...Option) Clie svcInfo := mocks.ServiceInfo() cli, err := NewClient(svcInfo, opts...) - test.Assert(tb, err == nil) + test.Assert(tb, err == nil, err) return cli } diff --git a/client/option_advanced.go b/client/option_advanced.go index 54a48cec31..f5d24487ca 100644 --- a/client/option_advanced.go +++ b/client/option_advanced.go @@ -245,3 +245,11 @@ func WithGRPCTLSConfig(tlsConfig *tls.Config) Option { o.GRPCConnectOpts.TLSConfig = grpc.TLSConfig(tlsConfig) }} } + +// WithTTHeaderFrameMetaHandler sets the FrameMetaHandler for TTHeader Streaming +func WithTTHeaderFrameMetaHandler(h remote.FrameMetaHandler) Option { + return Option{F: func(o *client.Options, di *utils.Slice) { + di.Push(fmt.Sprintf("WithTTHeaderFrameMetaHandler(%T)", h)) + o.RemoteOpt.TTHeaderFrameMetaHandler = h + }} +} diff --git a/client/stream.go b/client/stream.go index a184ac7ecc..60d9a21001 100644 --- a/client/stream.go +++ b/client/stream.go @@ -18,17 +18,17 @@ package client import ( "context" + "fmt" "io" "sync/atomic" - "github.com/cloudwego/kitex/pkg/kerrors" - "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" - "github.com/cloudwego/kitex/pkg/serviceinfo" - "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/remotecli" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streaming" ) @@ -71,10 +71,16 @@ func (kc *kClient) invokeRecvEndpoint() endpoint.RecvEndpoint { } func (kc *kClient) invokeStreamingEndpoint() (endpoint.Endpoint, error) { - handler, err := kc.opt.RemoteOpt.CliHandlerFactory.NewTransHandler(kc.opt.RemoteOpt) + handler, err := kc.newStreamClientTransHandler() if err != nil { - return nil, err + // if a ClientTransHandler does not support streaming, the error will be returned when Kitex is used to + // send streaming requests. This saves the pain for each ClientTransHandler to implement the interface, + // with no impact on existing ClientTransHandler implementations. + return func(ctx context.Context, req, resp interface{}) (err error) { + panic(err) + }, nil } + for _, h := range kc.opt.MetaHandlers { if shdlr, ok := h.(remote.StreamingMetaHandler); ok { kc.opt.RemoteOpt.StreamingMetaHandlers = append(kc.opt.RemoteOpt.StreamingMetaHandlers, shdlr) @@ -97,6 +103,15 @@ func (kc *kClient) invokeStreamingEndpoint() (endpoint.Endpoint, error) { }, nil } +func (kc *kClient) newStreamClientTransHandler() (remote.ClientTransHandler, error) { + handlerFactory, ok := kc.opt.RemoteOpt.CliHandlerFactory.(remote.ClientStreamTransHandlerFactory) + if !ok { + return nil, fmt.Errorf("remote.ClientStreamTransHandlerFactory is not implement by %T", + kc.opt.RemoteOpt.CliHandlerFactory) + } + return handlerFactory.NewStreamTransHandler(kc.opt.RemoteOpt) +} + func (kc *kClient) getStreamingMode(ri rpcinfo.RPCInfo) serviceinfo.StreamingMode { methodInfo := kc.svcInfo.MethodInfo(ri.Invocation().MethodName()) if methodInfo == nil { @@ -204,7 +219,7 @@ func (s *stream) DoFinish(err error) { err = nil } if s.scm != nil { - s.scm.ReleaseConn(err, s.ri) + s.scm.ReleaseConn(s, err, s.ri) } s.kc.opt.TracerCtl.DoFinish(s.Context(), s.ri, err) } diff --git a/client/stream_test.go b/client/stream_test.go index 9c07647d2a..ccdd7c902a 100644 --- a/client/stream_test.go +++ b/client/stream_test.go @@ -32,6 +32,7 @@ import ( "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/remotecli" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" @@ -72,6 +73,11 @@ func TestStream(t *testing.T) { test.Assert(t, err == nil, err) } +func init() { + mocks.DefaultNewStreamFunc = nphttp2.NewStream + mocks.DefaultNewStreamTransHandlerFunc = nphttp2.NewCliTransHandlerFactory().NewTransHandler +} + func TestStreaming(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -102,7 +108,8 @@ func TestStreaming(t *testing.T) { connpool := mock_remote.NewMockConnPool(ctrl) connpool.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(conn, nil) cliInfo.ConnPool = connpool - s, cr, _ := remotecli.NewStream(ctx, mockRPCInfo, new(mocks.MockCliTransHandler), cliInfo) + handler := new(mocks.MockCliTransHandler) + s, cr, _ := remotecli.NewStream(ctx, mockRPCInfo, handler, cliInfo) stream := newStream( s, cr, kc, mockRPCInfo, serviceinfo.StreamingBidirectional, func(stream streaming.Stream, message interface{}) (err error) { @@ -177,10 +184,11 @@ func TestClosedClient(t *testing.T) { type mockStream struct { streaming.Stream - ctx context.Context - close func() error - header func() (metadata.MD, error) - recv func(msg interface{}) error + ctx context.Context + trailer func() metadata.MD + close func() error + header func() (metadata.MD, error) + recv func(msg interface{}) error } func (s *mockStream) Context() context.Context { @@ -195,6 +203,10 @@ func (s *mockStream) RecvMsg(msg interface{}) error { return s.recv(msg) } +func (s *mockStream) Trailer() metadata.MD { + return s.trailer() +} + func (s *mockStream) Close() error { return s.close() } @@ -322,7 +334,8 @@ func Test_stream_RecvMsg(t *testing.T) { }) t.Run("no-error-client-streaming", func(t *testing.T) { - ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), nil, rpcinfo.NewRPCStats()) + cfg := rpcinfo.NewRPCConfig() + ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), cfg, rpcinfo.NewRPCStats()) st := &mockStream{ ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri), } @@ -470,7 +483,7 @@ func Test_stream_DoFinish(t *testing.T) { defer ctrl.Finish() t.Run("no-error", func(t *testing.T) { - ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, rpcinfo.NewRPCStats()) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) st := &mockStream{ ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri), } @@ -500,7 +513,7 @@ func Test_stream_DoFinish(t *testing.T) { }) t.Run("EOF", func(t *testing.T) { - ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, rpcinfo.NewRPCStats()) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) st := &mockStream{ ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri), } @@ -530,7 +543,7 @@ func Test_stream_DoFinish(t *testing.T) { }) t.Run("biz-status-error", func(t *testing.T) { - ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, rpcinfo.NewRPCStats()) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) st := &mockStream{ ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri), } @@ -637,3 +650,42 @@ func Test_isRPCError(t *testing.T) { test.Assert(t, isRPCError(errors.New("error"))) }) } + +type mockHandlerFactory struct { + remote.ClientTransHandlerFactory +} + +var _ remote.ClientStreamTransHandlerFactory = (*mockHandlerFactoryWithNewStream)(nil) + +type mockHandlerFactoryWithNewStream struct { + remote.ClientTransHandlerFactory +} + +func (m mockHandlerFactoryWithNewStream) NewStreamTransHandler(opt *remote.ClientOption) (remote.ClientTransHandler, error) { + return nil, nil +} + +func Test_kClient_newStreamClientTransHandler(t *testing.T) { + t.Run("not-implemented-NewStream", func(t *testing.T) { + kc := &kClient{ + opt: &client.Options{ + RemoteOpt: &remote.ClientOption{ + CliHandlerFactory: &mockHandlerFactory{}, + }, + }, + } + _, err := kc.newStreamClientTransHandler() + test.Assert(t, err != nil, err) // not implemented + }) + t.Run("implemented-NewStream", func(t *testing.T) { + kc := &kClient{ + opt: &client.Options{ + RemoteOpt: &remote.ClientOption{ + CliHandlerFactory: &mockHandlerFactoryWithNewStream{}, + }, + }, + } + _, err := kc.newStreamClientTransHandler() + test.Assert(t, err == nil) + }) +} diff --git a/client/streamclient/ttheader_option.go b/client/streamclient/ttheader_option.go new file mode 100644 index 0000000000..d93cb49973 --- /dev/null +++ b/client/streamclient/ttheader_option.go @@ -0,0 +1,44 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package streamclient + +import ( + "fmt" + + "github.com/cloudwego/kitex/client" + "github.com/cloudwego/kitex/pkg/utils" +) + +// WithTTHeaderStreamingWaitMetaFrame instructs Kitex client to wait for meta frame +func WithTTHeaderStreamingWaitMetaFrame(wait bool) Option { + return Option{ + F: func(o *client.Options, di *utils.Slice) { + di.Push(fmt.Sprintf("WithTTHeaderStreamingWaitMetaFrame(%v)", wait)) + o.RemoteOpt.TTHeaderStreamingWaitMetaFrame = wait + }, + } +} + +// WithTTHeaderStreamingGRPCCompatible instructs Kitex client to send grpc style metadata to the server +func WithTTHeaderStreamingGRPCCompatible() Option { + return Option{ + F: func(o *client.Options, di *utils.Slice) { + di.Push("WithTTHeaderStreamingGRPCCompatible()") + o.RemoteOpt.TTHeaderStreamingGRPCCompatible = true + }, + } +} diff --git a/internal/mocks/transhandlerclient.go b/internal/mocks/transhandlerclient.go index fd64d8b009..fcad357ad9 100644 --- a/internal/mocks/transhandlerclient.go +++ b/internal/mocks/transhandlerclient.go @@ -20,28 +20,53 @@ import ( "context" "net" + mocksremote "github.com/cloudwego/kitex/internal/mocks/remote" "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote" + kitex "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streaming" +) + +var ( + _ remote.ClientStreamAllocator = (*MockCliTransHandler)(nil) + _ remote.ClientTransHandlerFactory = (*mockCliTransHandlerFactory)(nil) + _ remote.ClientStreamTransHandlerFactory = (*mockCliTransHandlerFactory)(nil) + DefaultNewStreamFunc func( + ctx context.Context, svcInfo *kitex.ServiceInfo, conn net.Conn, handler remote.TransReadWriter, + ) streaming.Stream = nil + + DefaultNewStreamTransHandlerFunc func(opt *remote.ClientOption) (remote.ClientTransHandler, error) = nil ) type mockCliTransHandlerFactory struct { - hdlr *MockCliTransHandler + hdlr *mocksremote.MockClientTransHandler + NewStreamTransHandlerFunc func(opt *remote.ClientOption) (remote.ClientTransHandler, error) +} + +func (f *mockCliTransHandlerFactory) NewStreamTransHandler(opt *remote.ClientOption) (remote.ClientTransHandler, error) { + if f.NewStreamTransHandlerFunc != nil { + return f.NewStreamTransHandlerFunc(opt) + } + if DefaultNewStreamTransHandlerFunc != nil { + return DefaultNewStreamTransHandlerFunc(opt) + } + panic("NewStreamTransHandlerFunc is not available in mocks.mockCliTransHandlerFactory") } // NewMockCliTransHandlerFactory . -func NewMockCliTransHandlerFactory(hdrl *MockCliTransHandler) remote.ClientTransHandlerFactory { - return &mockCliTransHandlerFactory{hdrl} +func NewMockCliTransHandlerFactory(handler *mocksremote.MockClientTransHandler) remote.ClientTransHandlerFactory { + return &mockCliTransHandlerFactory{ + hdlr: handler, + } } func (f *mockCliTransHandlerFactory) NewTransHandler(opt *remote.ClientOption) (remote.ClientTransHandler, error) { - f.hdlr.opt = opt return f.hdlr, nil } // MockCliTransHandler . type MockCliTransHandler struct { - opt *remote.ClientOption transPipe *remote.TransPipeline WriteFunc func(ctx context.Context, conn net.Conn, send remote.Message) (nctx context.Context, err error) @@ -49,6 +74,18 @@ type MockCliTransHandler struct { ReadFunc func(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) OnMessageFunc func(ctx context.Context, args, result remote.Message) (context.Context, error) + + NewStreamFunc func(ctx context.Context, svcInfo *kitex.ServiceInfo, conn net.Conn, handler remote.TransReadWriter) (streaming.Stream, error) +} + +func (t *MockCliTransHandler) NewStream(ctx context.Context, svcInfo *kitex.ServiceInfo, conn net.Conn, handler remote.TransReadWriter) (streaming.Stream, error) { + if t.NewStreamFunc != nil { + return t.NewStreamFunc(ctx, svcInfo, conn, handler) + } + if DefaultNewStreamFunc != nil { + return DefaultNewStreamFunc(ctx, svcInfo, conn, handler), nil + } + panic("NewStreamFunc is not available in mocks.MockCliTransHandler") } // Write implements the remote.TransHandler interface. diff --git a/internal/server/remote_option.go b/internal/server/remote_option.go index a814df152c..6ecc82d8a7 100644 --- a/internal/server/remote_option.go +++ b/internal/server/remote_option.go @@ -31,8 +31,14 @@ import ( func newServerRemoteOption() *remote.ServerOption { return &remote.ServerOption{ - TransServerFactory: netpoll.NewTransServerFactory(), - SvrHandlerFactory: detection.NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()), + TransServerFactory: netpoll.NewTransServerFactory(), + SvrHandlerFactory: detection.NewSvrTransHandlerFactory( + // default transHandler + netpoll.NewSvrTransHandlerFactory(), + // detectable transHandlers + netpoll.NewTTHeaderStreamingSvrTransHandlerFactory(), + nphttp2.NewSvrTransHandlerFactory(), + ), Codec: codec.NewDefaultCodec(), Address: defaultAddress, ExitWaitTime: defaultExitWaitTime, diff --git a/internal/server/remote_option_windows.go b/internal/server/remote_option_windows.go index 79fced6f97..b60ed7b654 100644 --- a/internal/server/remote_option_windows.go +++ b/internal/server/remote_option_windows.go @@ -31,8 +31,14 @@ import ( func newServerRemoteOption() *remote.ServerOption { return &remote.ServerOption{ - TransServerFactory: gonet.NewTransServerFactory(), - SvrHandlerFactory: detection.NewSvrTransHandlerFactory(gonet.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()), + TransServerFactory: gonet.NewTransServerFactory(), + SvrHandlerFactory: detection.NewSvrTransHandlerFactory( + // default transHandler + gonet.NewSvrTransHandlerFactory(), + // detectable transHandlers + gonet.NewTTHeaderStreamingSvrTransHandlerFactory(), + nphttp2.NewSvrTransHandlerFactory(), + ), Codec: codec.NewDefaultCodec(), Address: defaultAddress, ExitWaitTime: defaultExitWaitTime, diff --git a/pkg/remote/trans/nphttp2/stream_middleware.go b/pkg/endpoint/stream_middleware.go similarity index 72% rename from pkg/remote/trans/nphttp2/stream_middleware.go rename to pkg/endpoint/stream_middleware.go index 4523aeb1fe..580130bccb 100644 --- a/pkg/remote/trans/nphttp2/stream_middleware.go +++ b/pkg/endpoint/stream_middleware.go @@ -14,10 +14,9 @@ * limitations under the License. */ -package nphttp2 +package endpoint import ( - ep "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/streaming" ) @@ -25,11 +24,12 @@ import ( type streamWithMiddleware struct { streaming.Stream - recvEndpoint ep.RecvEndpoint - sendEndpoint ep.SendEndpoint + recvEndpoint RecvEndpoint + sendEndpoint SendEndpoint } -func newStreamWithMiddleware(st streaming.Stream, recv ep.RecvEndpoint, send ep.SendEndpoint) *streamWithMiddleware { +// NewStreamWithMiddleware creates a new Stream with recv/send middleware support +func NewStreamWithMiddleware(st streaming.Stream, recv RecvEndpoint, send SendEndpoint) streaming.Stream { return &streamWithMiddleware{ Stream: st, recvEndpoint: recv, @@ -37,10 +37,12 @@ func newStreamWithMiddleware(st streaming.Stream, recv ep.RecvEndpoint, send ep. } } +// RecvMsg implements the streaming.Stream interface with recv middlewares func (s *streamWithMiddleware) RecvMsg(m interface{}) error { return s.recvEndpoint(s.Stream, m) } +// SendMsg implements the streaming.Stream interface with send middlewares func (s *streamWithMiddleware) SendMsg(m interface{}) error { return s.sendEndpoint(s.Stream, m) } diff --git a/pkg/generic/http_test/generic_test.go b/pkg/generic/http_test/generic_test.go index 246934e0c9..82dd0e3f05 100644 --- a/pkg/generic/http_test/generic_test.go +++ b/pkg/generic/http_test/generic_test.go @@ -141,13 +141,15 @@ func testThriftNormalBinaryEcho(t *testing.T) { // write: dynamicgo (amd64 && go1.16), fallback (arm || !go1.16) // read: dynamicgo - cli = initThriftClientByIDL(t, transport.PurePayload, addr, "./idl/binary_echo.thrift", opts, false, true) - resp, err = cli.GenericCall(context.Background(), "", customReq, callopt.WithRPCTimeout(100*time.Second)) - test.Assert(t, err == nil, err) - gr, ok = resp.(*generic.HTTPResponse) - test.Assert(t, ok) - test.Assert(t, reflect.DeepEqual(gjson.Get(string(gr.RawBody), "msg").String(), base64.StdEncoding.EncodeToString([]byte(mockMyMsg))), gjson.Get(string(gr.RawBody), "msg").String()) - test.Assert(t, reflect.DeepEqual(gjson.Get(string(gr.RawBody), "num").String(), "0"), gjson.Get(string(gr.RawBody), "num").String()) + // TODO: this test does not make sense since currently DynamicGo requires TTHeader/Framed for PayloadSize + // TODO: As we now have SkipDecoder, the compatibility of DynamicGo can be improved based on that. + // cli = initThriftClientByIDL(t, transport.PurePayload, addr, "./idl/binary_echo.thrift", opts, false, true) + // resp, err = cli.GenericCall(context.Background(), "", customReq, callopt.WithRPCTimeout(100*time.Second)) + // test.Assert(t, err == nil, err) + // gr, ok = resp.(*generic.HTTPResponse) + // test.Assert(t, ok) + // test.Assert(t, reflect.DeepEqual(gjson.Get(string(gr.RawBody), "msg").String(), base64.StdEncoding.EncodeToString([]byte(mockMyMsg))), gjson.Get(string(gr.RawBody), "msg").String()) + // test.Assert(t, reflect.DeepEqual(gjson.Get(string(gr.RawBody), "num").String(), "0"), gjson.Get(string(gr.RawBody), "num").String()) // write: dynamicgo (amd64 && go1.16), fallback (arm || !go1.16) // read: fallback diff --git a/pkg/generic/httppbthrift_codec_test.go b/pkg/generic/httppbthrift_codec_test.go index cbe1b67b2e..7f19f337ac 100644 --- a/pkg/generic/httppbthrift_codec_test.go +++ b/pkg/generic/httppbthrift_codec_test.go @@ -37,7 +37,7 @@ func TestFromHTTPPbRequest(t *testing.T) { mockey.Mock(ioutil.ReadAll).Return([]byte("123"), nil).Build() hreq, err := FromHTTPPbRequest(req) test.Assert(t, err == nil) - test.Assert(t, reflect.DeepEqual(hreq.RawBody, []byte("123")), string(hreq.RawBody)) + test.Assert(t, reflect.DeepEqual(hreq.RawBody, []byte("123")), "should run with `-gcflags=\"all=-N -l\"`") test.Assert(t, hreq.GetMethod() == "POST") test.Assert(t, hreq.GetPath() == "/far/boo") }) diff --git a/pkg/protocol/bthrift/binary_test.go b/pkg/protocol/bthrift/binary_test.go index 2e42b2defd..7f47dff6ad 100644 --- a/pkg/protocol/bthrift/binary_test.go +++ b/pkg/protocol/bthrift/binary_test.go @@ -26,7 +26,7 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote" - internalnetpoll "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" + "github.com/cloudwego/kitex/pkg/remote/trans/netpoll/bytebuf" ) // TestWriteMessageEnd test binary WriteMessageEnd function @@ -319,7 +319,7 @@ func TestWriteStringNocopy(t *testing.T) { buf := make([]byte, 128) exceptWs := "0000000c6d657373616765426567696e" exceptSize := 16 - out := internalnetpoll.NewReaderWriterByteBuffer(netpoll.NewLinkBuffer(0)) + out := bytebuf.NewReaderWriterByteBuffer(netpoll.NewLinkBuffer(0)) nw, _ := out.(remote.NocopyWrite) wn := Binary.WriteStringNocopy(buf, nw, "messageBegin") ws := fmt.Sprintf("%x", buf[:wn]) @@ -332,7 +332,7 @@ func TestWriteBinaryNocopy(t *testing.T) { buf := make([]byte, 128) exceptWs := "0000000c6d657373616765426567696e" exceptSize := 16 - out := internalnetpoll.NewReaderWriterByteBuffer(netpoll.NewLinkBuffer(0)) + out := bytebuf.NewReaderWriterByteBuffer(netpoll.NewLinkBuffer(0)) nw, _ := out.(remote.NocopyWrite) wn := Binary.WriteBinaryNocopy(buf, nw, []byte("messageBegin")) ws := fmt.Sprintf("%x", buf[:wn]) diff --git a/pkg/remote/codec/bytebuf_util.go b/pkg/remote/codec/bytebuf_util.go index aac9d2d186..51f90a599a 100644 --- a/pkg/remote/codec/bytebuf_util.go +++ b/pkg/remote/codec/bytebuf_util.go @@ -175,3 +175,15 @@ func ReadString2BLen(bytes []byte, off int) (string, int, error) { buf := bytes[off : off+strLen] return string(buf), int(length) + 2, nil } + +// WriteAll ... +func WriteAll(writer func([]byte) (int, error), buf []byte) error { + for idx := 0; idx < len(buf); { + n, err := writer(buf[idx:]) + if err != nil { + return err + } + idx += n + } + return nil +} diff --git a/pkg/remote/codec/default_codec.go b/pkg/remote/codec/default_codec.go index a4d29c4523..248e868d79 100644 --- a/pkg/remote/codec/default_codec.go +++ b/pkg/remote/codec/default_codec.go @@ -55,6 +55,10 @@ var ( _ remote.MetaDecoder = (*defaultCodec)(nil) ) +func TTHeaderCodec() ttHeader { + return ttHeaderCodec +} + // NewDefaultCodec creates the default protocol sniffing codec supporting thrift and protobuf. func NewDefaultCodec() remote.Codec { // No size limit by default diff --git a/pkg/remote/codec/default_codec_test.go b/pkg/remote/codec/default_codec_test.go index a1a2509671..b11b51e34c 100644 --- a/pkg/remote/codec/default_codec_test.go +++ b/pkg/remote/codec/default_codec_test.go @@ -30,7 +30,7 @@ import ( mocksremote "github.com/cloudwego/kitex/internal/mocks/remote" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote" - netpolltrans "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" + "github.com/cloudwego/kitex/pkg/remote/trans/netpoll/bytebuf" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/transport" @@ -49,7 +49,7 @@ var transportBuffers = []struct { { Name: "NetpollBuffer", NewBuffer: func() remote.ByteBuffer { - return netpolltrans.NewReaderWriterByteBuffer(netpoll.NewLinkBuffer(1024)) + return bytebuf.NewReaderWriterByteBuffer(netpoll.NewLinkBuffer(1024)) }, }, } diff --git a/pkg/remote/codec/grpc/grpc.go b/pkg/remote/codec/grpc/grpc.go index 0b417d2606..1720d9121b 100644 --- a/pkg/remote/codec/grpc/grpc.go +++ b/pkg/remote/codec/grpc/grpc.go @@ -43,13 +43,10 @@ type marshaler interface { Size() int } -type protobufV2MsgCodec interface { - XXX_Unmarshal(b []byte) error - XXX_Marshal(b []byte, deterministic bool) ([]byte, error) -} - type grpcCodec struct { - ThriftCodec remote.PayloadCodec + ThriftCodec remote.PayloadCodec + thriftStructCodec remote.StructCodec + protobufStructCodec remote.StructCodec } type CodecOption func(c *grpcCodec) @@ -67,7 +64,17 @@ func NewGRPCCodec(opts ...CodecOption) remote.Codec { opt(codec) } if !thrift.IsThriftCodec(codec.ThriftCodec) { - codec.ThriftCodec = thrift.NewThriftCodec() + if c, err := remote.GetPayloadCodecByCodecType(serviceinfo.Thrift); err == nil { + codec.ThriftCodec = c + } else { + codec.ThriftCodec = thrift.NewThriftCodec() + } + } + codec.thriftStructCodec = codec.ThriftCodec.(remote.StructCodec) + if c, err := remote.GetStructCodecByCodecType(serviceinfo.Protobuf); err == nil { + codec.protobufStructCodec = c + } else { + codec.protobufStructCodec = protobuf.NewProtobufCodec().(remote.StructCodec) } return codec } @@ -99,7 +106,7 @@ func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remo switch message.ProtocolInfo().CodecType { case serviceinfo.Thrift: - payload, err = thrift.MarshalThriftData(ctx, c.ThriftCodec, message.Data()) + payload, err = c.thriftStructCodec.Serialize(ctx, message.Data()) case serviceinfo.Protobuf: switch t := message.Data().(type) { case fastpb.Writer: @@ -126,7 +133,7 @@ func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remo if _, err = t.MarshalTo(payload); err != nil { return err } - case protobufV2MsgCodec: + case protobuf.ProtobufV2MsgCodec: payload, err = t.XXX_Marshal(nil, true) case proto.Message: payload, err = proto.Marshal(t) @@ -171,32 +178,11 @@ func (c *grpcCodec) Decode(ctx context.Context, message remote.Message, in remot return err } message.SetPayloadLen(len(d)) - data := message.Data() switch message.ProtocolInfo().CodecType { case serviceinfo.Thrift: - return thrift.UnmarshalThriftData(ctx, c.ThriftCodec, "", d, message.Data()) + return c.thriftStructCodec.Deserialize(ctx, message.Data(), d) case serviceinfo.Protobuf: - if t, ok := data.(fastpb.Reader); ok { - if len(d) == 0 { - // if all fields of a struct is default value, data will be nil - // In the implementation of fastpb, if data is nil, then fastpb will skip creating this struct, as a result user will get a nil pointer which is not expected. - // So, when data is nil, use default protobuf unmarshal method to decode the struct. - // todo: fix fastpb - } else { - _, err = fastpb.ReadMessage(d, fastpb.SkipTypeCheck, t) - return err - } - } - switch t := data.(type) { - case protobufV2MsgCodec: - return t.XXX_Unmarshal(d) - case proto.Message: - return proto.Unmarshal(d, t) - case protobuf.ProtobufMsgCodec: - return t.Unmarshal(d) - default: - return ErrInvalidPayload - } + return c.protobufStructCodec.Deserialize(ctx, message.Data(), d) default: return ErrInvalidPayload } diff --git a/pkg/remote/codec/grpc/grpc_test.go b/pkg/remote/codec/grpc/grpc_test.go new file mode 100644 index 0000000000..0e74de6465 --- /dev/null +++ b/pkg/remote/codec/grpc/grpc_test.go @@ -0,0 +1,167 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package grpc + +import ( + "context" + "encoding/binary" + "reflect" + "testing" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/protocol/bthrift" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" + "github.com/cloudwego/kitex/pkg/remote/codec/thrift" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" +) + +func TestNewGRPCCodec(t *testing.T) { + t.Run("default", func(t *testing.T) { + g := NewGRPCCodec().(*grpcCodec) + thriftCodec := g.ThriftCodec + codec := reflect.ValueOf(thriftCodec).Elem().Field(0).Interface().(thrift.CodecType) + test.Assert(t, codec == thrift.FastReadWrite) + test.Assert(t, g.thriftStructCodec != nil) + test.Assert(t, g.protobufStructCodec != nil) + }) + + t.Run("with-config", func(t *testing.T) { + codecWithConfig := WithThriftCodec(thrift.NewThriftCodecWithConfig(thrift.FastRead)) + g := NewGRPCCodec(codecWithConfig).(*grpcCodec) + thriftCodec := g.ThriftCodec + codec := reflect.ValueOf(thriftCodec).Elem().Field(0).Interface().(thrift.CodecType) + test.Assert(t, codec == thrift.FastRead) + }) +} + +var _ thrift.ThriftMsgFastCodec = (*thriftStruct)(nil) + +type thriftStruct struct { + A int +} + +func (t *thriftStruct) BLength() int { + return 4 +} + +// FastWriteNocopy not real implementation of thrift encoding, just for test +func (t *thriftStruct) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { + binary.BigEndian.PutUint32(buf, uint32(t.A)) + return 4 +} + +// FastRead not real implementation of thrift decoding, just for test +func (t *thriftStruct) FastRead(buf []byte) (int, error) { + t.A = int(binary.BigEndian.Uint32(buf)) + return 4, nil +} + +type pbStruct struct { + A int +} + +var _ protobuf.ProtobufV2MsgCodec = (*pbStruct)(nil) + +// not real implementation of pb decoding, just for test +func (p *pbStruct) XXX_Unmarshal(b []byte) error { + p.A = int(binary.BigEndian.Uint32(b)) + return nil +} + +// not real implementation of pb encoding, just for test +func (p *pbStruct) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, uint32(p.A)) + return buf, nil +} + +func Test_grpcCodec_Encode(t *testing.T) { + t.Run("thrift", func(t *testing.T) { + g := NewGRPCCodec().(*grpcCodec) + data := &thriftStruct{A: 1} + stat := rpcinfo.NewRPCStats() + ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, stat) + msg := remote.NewMessage(data, nil, ri, remote.Call, remote.Client) + out := remote.NewWriterBuffer(1024) + + err := g.Encode(context.Background(), msg, out) + + test.Assert(t, err == nil, err) + buf, err := out.Bytes() + test.Assert(t, err == nil, err) + test.Assert(t, len(buf) > 0) + }) + + t.Run("protobuf", func(t *testing.T) { + g := NewGRPCCodec().(*grpcCodec) + data := &pbStruct{} + stat := rpcinfo.NewRPCStats() + ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, stat) + msg := remote.NewMessage(data, nil, ri, remote.Call, remote.Client) + msg.SetProtocolInfo(remote.ProtocolInfo{CodecType: serviceinfo.Protobuf}) + out := remote.NewWriterBuffer(1024) + + err := g.Encode(context.Background(), msg, out) + + test.Assert(t, err == nil, err) + buf, err := out.Bytes() + test.Assert(t, err == nil, err) + test.Assert(t, len(buf) > 0, len(buf)) + }) +} + +func Test_grpcCodec_Decode(t *testing.T) { + t.Run("thrift", func(t *testing.T) { + g := NewGRPCCodec().(*grpcCodec) + buf := make([]byte, 9) + buf[0] = 0 // frame: not compressed + binary.BigEndian.PutUint32(buf[1:], 4) // frame: data size + binary.BigEndian.PutUint32(buf[5:], 1) // frame: data + data := &thriftStruct{} + stat := rpcinfo.NewRPCStats() + ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, stat) + msg := remote.NewMessage(data, nil, ri, remote.Call, remote.Client) + in := remote.NewReaderBuffer(buf) + + err := g.Decode(context.Background(), msg, in) + + test.Assert(t, err == nil, err) + got := msg.Data().(*thriftStruct) + test.Assert(t, got.A == 1, got) + }) + t.Run("protobuf", func(t *testing.T) { + g := NewGRPCCodec().(*grpcCodec) + buf := make([]byte, 9) + buf[0] = 0 // frame: not compressed + binary.BigEndian.PutUint32(buf[1:], 4) // frame: data size + binary.BigEndian.PutUint32(buf[5:], 1) // frame: data + data := &pbStruct{} + stat := rpcinfo.NewRPCStats() + ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, stat) + msg := remote.NewMessage(data, nil, ri, remote.Call, remote.Client) + msg.SetProtocolInfo(remote.ProtocolInfo{CodecType: serviceinfo.Protobuf}) + in := remote.NewReaderBuffer(buf) + + err := g.Decode(context.Background(), msg, in) + + test.Assert(t, err == nil, err) + got := msg.Data().(*pbStruct) + test.Assert(t, got.A == 1, got) + }) +} diff --git a/pkg/remote/codec/header_codec.go b/pkg/remote/codec/header_codec.go index 524dd6a903..459941d41f 100644 --- a/pkg/remote/codec/header_codec.go +++ b/pkg/remote/codec/header_codec.go @@ -69,7 +69,7 @@ import ( const ( // Header Magics // 0 and 16th bits must be 0 to differentiate from framed & unframed - TTHeaderMagic uint32 = 0x10000000 + TTHeaderMagic uint32 = 0x10000000 // magic(2 bytes) + flags(2 bytes) MeshHeaderMagic uint32 = 0xFFAF0000 MeshHeaderLenMask uint32 = 0x0000FFFF @@ -85,6 +85,7 @@ type HeaderFlags uint16 const ( HeaderFlagsKey string = "HeaderFlags" HeaderFlagSupportOutOfOrder HeaderFlags = 0x01 + HeaderFlagsStreaming HeaderFlags = 0b0000_0000_0000_0010 HeaderFlagDuplexReverse HeaderFlags = 0x08 HeaderFlagSASL HeaderFlags = 0x10 ) @@ -93,6 +94,16 @@ const ( TTHeaderMetaSize = 14 ) +const ( + FrameTypeMeta = "1" + FrameTypeHeader = "2" + FrameTypeData = "3" + FrameTypeTrailer = "4" + FrameTypeInvalid = "" + + StrKeyMetaData = "grpc-metadata" +) + // ProtocolID is the wrapped protocol id used in THeader. type ProtocolID uint8 @@ -102,9 +113,22 @@ const ( ProtocolIDThriftCompact ProtocolID = 0x02 // Kitex not support ProtocolIDThriftCompactV2 ProtocolID = 0x03 // Kitex not support ProtocolIDKitexProtobuf ProtocolID = 0x04 + ProtocolIDThriftStruct ProtocolID = 0x10 // TTHeader Streaming: only thrift struct encoded, no magic + ProtocolIDProtobufStruct ProtocolID = 0x11 // TTHeader Streaming: only protobuf struct encoded, no magic + ProtocolIDNotSpecified ProtocolID = 0xff ProtocolIDDefault = ProtocolIDThriftBinary ) +var protocolIDToPayloadCodecMap = map[ProtocolID]serviceinfo.PayloadCodec{ + ProtocolIDThriftBinary: serviceinfo.Thrift, + ProtocolIDThriftCompact: serviceinfo.Thrift, + ProtocolIDThriftCompactV2: serviceinfo.Thrift, + ProtocolIDKitexProtobuf: serviceinfo.Protobuf, + ProtocolIDThriftStruct: serviceinfo.Thrift, + ProtocolIDProtobufStruct: serviceinfo.Protobuf, + ProtocolIDNotSpecified: serviceinfo.NotSpecified, +} + type InfoIDType uint8 // uint8 const ( @@ -114,8 +138,25 @@ const ( InfoIDACLToken InfoIDType = 0x11 ) +var ( + errNotTTHeaderStreaming = perrors.NewProtocolErrorWithMsg("not TTHeader Streaming protocol") + errNotTTHeader = perrors.NewProtocolErrorWithMsg("not TTHeader protocol") +) + type ttHeader struct{} +// EncodeHeader encodes the message as TTHeader format and writes it to out. +// A prefixed 4-bytes is returned for the caller to write the frame size (after calculated the payload size). +func (t ttHeader) EncodeHeader(ctx context.Context, message remote.Message, out remote.ByteBuffer) (totalLenField []byte, err error) { + return t.encode(ctx, message, out) +} + +// DecodeHeader decodes the message from TTHeader format from in (including the 4-bytes frame size prefix). +// PayloadLen is set to message for the caller to read the following payload. +func (t ttHeader) DecodeHeader(ctx context.Context, message remote.Message, in remote.ByteBuffer) error { + return t.decode(ctx, message, in) +} + func (t ttHeader) encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (totalLenField []byte, err error) { // 1. header meta var headerMeta []byte @@ -126,12 +167,17 @@ func (t ttHeader) encode(ctx context.Context, message remote.Message, out remote totalLenField = headerMeta[0:4] headerInfoSizeField := headerMeta[12:14] - binary.BigEndian.PutUint32(headerMeta[4:8], TTHeaderMagic+uint32(getFlags(message))) + magic := TTHeaderMagic | uint32(getFlags(message)) + if message.MessageType() == remote.Stream { + magic |= uint32(HeaderFlagsStreaming) + } + binary.BigEndian.PutUint32(headerMeta[4:8], magic) binary.BigEndian.PutUint32(headerMeta[8:12], uint32(message.RPCInfo().Invocation().SeqID())) + protocolID := getMessageProtocolID(message) var transformIDs []uint8 // transformIDs not support TODO compress // 2. header info, malloc and write - if err = WriteByte(byte(getProtocolID(message.ProtocolInfo())), out); err != nil { + if err = WriteByte(protocolID, out); err != nil { return nil, perrors.NewProtocolErrorWithMsg(fmt.Sprintf("ttHeader write protocol id failed, %s", err.Error())) } if err = WriteByte(byte(len(transformIDs)), out); err != nil { @@ -161,8 +207,10 @@ func (t ttHeader) decode(ctx context.Context, message remote.Message, in remote. if err != nil { return perrors.NewProtocolError(err) } - if !IsTTHeader(headerMeta) { - return perrors.NewProtocolErrorWithMsg("not TTHeader protocol") + if message.MessageType() == remote.Stream && !IsTTHeaderStreaming(headerMeta) { + return errNotTTHeaderStreaming + } else if !IsTTHeader(headerMeta) { + return errNotTTHeader } totalLen := Bytes2Uint32NoCheck(headerMeta[:Size32]) @@ -184,7 +232,7 @@ func (t ttHeader) decode(ctx context.Context, message remote.Message, in remote. if headerInfo, err = in.Next(int(headerInfoSize)); err != nil { return perrors.NewProtocolError(err) } - if err = checkProtocolID(headerInfo[0], message); err != nil { + if err = checkAndSetProtocolID(headerInfo[0], message); err != nil { return err } hdIdx := 2 @@ -408,6 +456,23 @@ func setFlags(flags uint16, message remote.Message) { } } +// getMessageProtocolID returns the protocol id of the message +// `getProtocolID` always returns ProtocolIDThriftBinary due to historical compatibility issues, +// but if it's a streaming message, it's necessary to distinguish Thrift/Protobuf +func getMessageProtocolID(message remote.Message) byte { + if message.MessageType() == remote.Stream { + switch message.ProtocolInfo().CodecType { + case serviceinfo.Thrift: + return byte(ProtocolIDThriftStruct) + case serviceinfo.Protobuf: + return byte(ProtocolIDProtobufStruct) + default: + return byte(ProtocolIDNotSpecified) + } + } + return byte(getProtocolID(message.ProtocolInfo())) +} + // protoID just for ttheader func getProtocolID(pi remote.ProtocolInfo) ProtocolID { switch pi.CodecType { @@ -422,16 +487,21 @@ func getProtocolID(pi remote.ProtocolInfo) ProtocolID { return ProtocolIDDefault } -// protoID just for ttheader -func checkProtocolID(protoID uint8, message remote.Message) error { - switch protoID { - case uint8(ProtocolIDThriftBinary): - case uint8(ProtocolIDKitexProtobuf): - case uint8(ProtocolIDThriftCompactV2): - // just for compatibility - default: - return fmt.Errorf("unsupported ProtocolID[%d]", protoID) +// checkAndSetProtocolID checks the validity of the protocol id and set the payloadCodec of the message +// For a streaming message, it's necessary to distinguish Thrift/Protobuf +func checkAndSetProtocolID(protoID uint8, message remote.Message) error { + payloadCodec, exists := protocolIDToPayloadCodecMap[ProtocolID(protoID)] + if !exists { + if message.MessageType() != remote.Stream { + return fmt.Errorf("unsupported ProtocolID[%d]", protoID) + } else { + payloadCodec = serviceinfo.NotSpecified + } } + message.SetProtocolInfo(remote.ProtocolInfo{ + TransProto: message.ProtocolInfo().TransProto, + CodecType: payloadCodec, + }) return nil } diff --git a/pkg/remote/codec/header_codec_test.go b/pkg/remote/codec/header_codec_test.go index 0c8066955e..5b18b198eb 100644 --- a/pkg/remote/codec/header_codec_test.go +++ b/pkg/remote/codec/header_codec_test.go @@ -19,6 +19,7 @@ package codec import ( "context" "encoding/binary" + "errors" "net" "testing" @@ -521,3 +522,123 @@ func TestHeaderFlags(t *testing.T) { hfs = getFlags(msg) test.Assert(t, hfs == HeaderFlagSupportOutOfOrder, hfs) } + +func Test_ttHeader_encode(t *testing.T) { + t.Run("flag:streaming", func(t *testing.T) { + msg := remote.NewMessage(nil, mocks.ServiceInfo(), mockCliRPCInfo, remote.Stream, remote.Client) + out := remote.NewWriterBuffer(256) + + _, err := ttHeaderCodec.encode(context.Background(), msg, out) + test.Assert(t, err == nil, err) + + bytes, err := out.Bytes() + test.Assert(t, err == nil, err) + flags := binary.BigEndian.Uint16(bytes[Size32+Size16:]) + test.Assert(t, flags|uint16(HeaderFlagsStreaming) != 0, flags) + }) + t.Run("protocol-id:streaming-thrift", func(t *testing.T) { + msg := remote.NewMessage(nil, mocks.ServiceInfo(), mockCliRPCInfo, remote.Stream, remote.Client) + msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, serviceinfo.Thrift)) + out := remote.NewWriterBuffer(256) + _, err := ttHeaderCodec.encode(context.Background(), msg, out) + test.Assert(t, err == nil, err) + + bytes, err := out.Bytes() + test.Assert(t, err == nil, err) + offsetProtocolID := Size32 /*FramedSize*/ + Size32 /*Magic+Flags*/ + Size32 /*SeqID*/ + Size16 /*HeaderSize*/ + protocolID := bytes[offsetProtocolID] + test.Assert(t, protocolID == byte(ProtocolIDThriftStruct), protocolID) + }) +} + +func Test_ttHeader_decode(t *testing.T) { + t.Run("magic-validation:ping-pong", func(t *testing.T) { + msg := remote.NewMessage(nil, mocks.ServiceInfo(), mockCliRPCInfo, remote.Call, remote.Client) + buf := make([]byte, 14) + err := ttHeaderCodec.decode(context.Background(), msg, remote.NewReaderBuffer(buf)) + test.Assert(t, errors.Is(err, errNotTTHeader), err) + }) + t.Run("magic-validation:streaming", func(t *testing.T) { + msg := remote.NewMessage(nil, mocks.ServiceInfo(), mockCliRPCInfo, remote.Stream, remote.Client) + buf := make([]byte, 14) + err := ttHeaderCodec.decode(context.Background(), msg, remote.NewReaderBuffer(buf)) + test.Assert(t, errors.Is(err, errNotTTHeaderStreaming), err) + }) + t.Run("protocol-id:protobuf", func(t *testing.T) { + msg := remote.NewMessage(nil, mocks.ServiceInfo(), mockCliRPCInfo, remote.Stream, remote.Client) + msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, serviceinfo.Protobuf)) + out := remote.NewWriterBuffer(256) + _, err := ttHeaderCodec.encode(context.Background(), msg, out) + test.Assert(t, err == nil, err) + bytes, _ := out.Bytes() + + rMsg := remote.NewMessage(nil, mocks.ServiceInfo(), mockCliRPCInfo, remote.Stream, remote.Client) + err = ttHeaderCodec.decode(context.Background(), rMsg, remote.NewReaderBuffer(bytes)) + + test.Assert(t, err == nil, err) + test.Assert(t, rMsg.ProtocolInfo().CodecType == serviceinfo.Protobuf) + }) +} + +func Test_getMessageProtocolID(t *testing.T) { + t.Run("Call", func(t *testing.T) { + msg := remote.NewMessage(nil, mocks.ServiceInfo(), mockCliRPCInfo, remote.Call, remote.Client) + protocolID := getMessageProtocolID(msg) + test.Assert(t, protocolID == byte(ProtocolIDDefault)) + }) + t.Run("Stream", func(t *testing.T) { + msg := remote.NewMessage(nil, mocks.ServiceInfo(), mockCliRPCInfo, remote.Stream, remote.Client) + protocolID := getMessageProtocolID(msg) + test.Assert(t, protocolID == byte(ProtocolIDThriftStruct)) + }) + t.Run("Stream:protobuf", func(t *testing.T) { + msg := remote.NewMessage(nil, mocks.ServiceInfo(), mockCliRPCInfo, remote.Stream, remote.Client) + msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, serviceinfo.Protobuf)) + protocolID := getMessageProtocolID(msg) + test.Assert(t, protocolID == byte(ProtocolIDProtobufStruct)) + }) + t.Run("Stream:NotSpecified", func(t *testing.T) { + msg := remote.NewMessage(nil, mocks.ServiceInfo(), mockCliRPCInfo, remote.Stream, remote.Client) + msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, serviceinfo.NotSpecified)) + protocolID := getMessageProtocolID(msg) + test.Assert(t, protocolID == byte(ProtocolIDNotSpecified)) + }) +} + +func Test_checkAndSetProtocolID(t *testing.T) { + t.Run("pingpong:thrift", func(t *testing.T) { + protocolID := byte(ProtocolIDThriftBinary) + msg := remote.NewMessage(nil, mocks.ServiceInfo(), mockCliRPCInfo, remote.Call, remote.Client) + err := checkAndSetProtocolID(protocolID, msg) + test.Assert(t, err == nil, err) + test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Thrift) + }) + t.Run("pingpong:protobuf", func(t *testing.T) { + protocolID := byte(ProtocolIDProtobufStruct) + msg := remote.NewMessage(nil, mocks.ServiceInfo(), mockCliRPCInfo, remote.Call, remote.Client) + err := checkAndSetProtocolID(protocolID, msg) + test.Assert(t, err == nil, err) + test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Protobuf) + }) + t.Run("pingpong:unknown", func(t *testing.T) { + protocolID := byte(ProtocolIDNotSpecified) - 1 // unknown + msg := remote.NewMessage(nil, mocks.ServiceInfo(), mockCliRPCInfo, remote.Call, remote.Client) + err := checkAndSetProtocolID(protocolID, msg) + test.Assert(t, err != nil, err) + }) + t.Run("stream:thrift", func(t *testing.T) { + protocolID := byte(ProtocolIDThriftStruct) + msg := remote.NewMessage(nil, mocks.ServiceInfo(), mockCliRPCInfo, remote.Stream, remote.Client) + err := checkAndSetProtocolID(protocolID, msg) + test.Assert(t, err == nil, err) + test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Thrift) + }) + t.Run("stream:protobuf", func(t *testing.T) { + protocolID := byte(ProtocolIDProtobufStruct) + msg := remote.NewMessage(nil, mocks.ServiceInfo(), mockCliRPCInfo, remote.Stream, remote.Client) + msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, serviceinfo.Protobuf)) + err := checkAndSetProtocolID(protocolID, msg) + test.Assert(t, err == nil, err) + test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Protobuf) + }) +} diff --git a/pkg/remote/codec/protobuf/protobuf.go b/pkg/remote/codec/protobuf/protobuf.go index a5a11e0aff..efe90e4672 100644 --- a/pkg/remote/codec/protobuf/protobuf.go +++ b/pkg/remote/codec/protobuf/protobuf.go @@ -46,6 +46,11 @@ const ( metaInfoFixLen = 8 ) +var ( + _ remote.PayloadCodec = (*protobufCodec)(nil) + _ remote.StructCodec = (*protobufCodec)(nil) +) + // NewProtobufCodec ... func NewProtobufCodec() remote.PayloadCodec { return &protobufCodec{} @@ -235,6 +240,11 @@ type ProtobufMsgCodec interface { Unmarshal(in []byte) error } +type ProtobufV2MsgCodec interface { + XXX_Unmarshal(b []byte) error + XXX_Marshal(b []byte, deterministic bool) ([]byte, error) +} + func getValidData(methodName string, message remote.Message) (interface{}, error) { if err := codec.NewDataIfNeeded(methodName, message); err != nil { return nil, err diff --git a/pkg/remote/codec/protobuf/protobuf_struct_codec.go b/pkg/remote/codec/protobuf/protobuf_struct_codec.go new file mode 100644 index 0000000000..efeafd9f4d --- /dev/null +++ b/pkg/remote/codec/protobuf/protobuf_struct_codec.go @@ -0,0 +1,90 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package protobuf + +import ( + "context" + "errors" + + "github.com/bytedance/gopkg/lang/mcache" + "github.com/cloudwego/fastpb" + "google.golang.org/protobuf/proto" +) + +// This is for streaming transports including grpc and ttheader + +var ( + ErrInvalidPayload = errors.New("protobuf: invalid payload") + ErrDataChanged = errors.New("protobuf: data changed between size calculation and serialization") + ErrDataUnmarshalable = errors.New("protobuf: data unmarshalable") +) + +type marshaler interface { + MarshalTo(data []byte) (n int, err error) + Size() int +} + +func (c protobufCodec) Serialize(ctx context.Context, data interface{}) (payload []byte, err error) { + switch t := data.(type) { + case fastpb.Writer: + size := t.Size() + payload = mcache.Malloc(size) + if n := t.FastWrite(payload); n != size { + return nil, ErrDataChanged + } + return payload, nil + case marshaler: + size := t.Size() + payload = mcache.Malloc(size) + _, err = t.MarshalTo(payload) + return payload, err + case ProtobufMsgCodec: + return t.Marshal(nil) + case ProtobufV2MsgCodec: + return t.XXX_Marshal(nil, true) + case proto.Message: + return proto.Marshal(t) + default: + return nil, ErrDataUnmarshalable + } +} + +// Deserialize implements the remote.StructCodec interface. +// It deserializes the protobuf struct in to message.Data(). +func (c protobufCodec) Deserialize(ctx context.Context, data interface{}, payload []byte) (err error) { + if t, ok := data.(fastpb.Reader); ok { + if len(payload) == 0 { + // if all fields of a struct is default value, data will be nil + // In the implementation of fastpb, if data is nil, then fastpb will skip creating this struct, as a result user will get a nil pointer which is not expected. + // So, when data is nil, use default protobuf unmarshal method to decode the struct. + // todo: fix fastpb + } else { + _, err = fastpb.ReadMessage(payload, fastpb.SkipTypeCheck, t) + return err + } + } + switch t := data.(type) { + case ProtobufV2MsgCodec: + return t.XXX_Unmarshal(payload) + case proto.Message: + return proto.Unmarshal(payload, t) + case ProtobufMsgCodec: + return t.Unmarshal(payload) + default: + return ErrInvalidPayload + } +} diff --git a/pkg/remote/codec/protobuf/protobuf_struct_codec_test.go b/pkg/remote/codec/protobuf/protobuf_struct_codec_test.go new file mode 100644 index 0000000000..de72345942 --- /dev/null +++ b/pkg/remote/codec/protobuf/protobuf_struct_codec_test.go @@ -0,0 +1,188 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package protobuf + +import ( + "context" + "errors" + "reflect" + "testing" + + "github.com/cloudwego/fastpb" + "google.golang.org/protobuf/encoding/protowire" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/generic/httppb_test/idl" +) + +// The following structs are just for test, not conforming to the protobuf specification +var ( + _ fastpb.Writer = (*fastStruct)(nil) + _ fastpb.Reader = (*fastStruct)(nil) + _ marshaler = (*marshalerStruct)(nil) + _ ProtobufMsgCodec = (*pbMsgStruct)(nil) + _ ProtobufV2MsgCodec = (*pbMsgV2Struct)(nil) +) + +type pbMsgV2Struct struct { + buf []byte +} + +func (p *pbMsgV2Struct) XXX_Unmarshal(b []byte) error { + p.buf = make([]byte, len(b)) + copy(p.buf, b) + return nil +} + +func (p *pbMsgV2Struct) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return p.buf, nil +} + +type pbMsgStruct struct { + buf []byte +} + +func (p *pbMsgStruct) Marshal(out []byte) ([]byte, error) { + return p.buf, nil +} + +func (p *pbMsgStruct) Unmarshal(in []byte) error { + p.buf = make([]byte, len(in)) + copy(p.buf, in) + return nil +} + +type fastStruct struct { + size int + buf []byte +} + +func (p *fastStruct) FastRead(buf []byte, _type int8, number int32) (n int, err error) { + p.buf = make([]byte, len(buf)) + return copy(buf, p.buf), nil +} + +func (p *fastStruct) Size() (n int) { + return p.size +} + +func (p *fastStruct) FastWrite(buf []byte) (n int) { + return copy(buf, p.buf) +} + +type marshalerStruct struct { + size int + buf []byte +} + +func (p *marshalerStruct) Size() (n int) { + return p.size +} + +func (p *marshalerStruct) MarshalTo(data []byte) (n int, err error) { + p.buf = make([]byte, len(data)) + return copy(data, p.buf), nil +} + +func Test_protobufCodec_Serialize(t *testing.T) { + c := &protobufCodec{} + t.Run("fastpb.Writer", func(t *testing.T) { + data := &fastStruct{ + size: 4, + buf: []byte{1, 2, 3, 4}, + } + + buf, err := c.Serialize(context.Background(), data) + test.Assert(t, err == nil, err) + test.Assert(t, reflect.DeepEqual(buf, data.buf), buf) + }) + t.Run("marshaler", func(t *testing.T) { + data := &marshalerStruct{ + size: 4, + buf: []byte{1, 2, 3, 4}, + } + buf, err := c.Serialize(context.Background(), data) + test.Assert(t, err == nil, err) + test.Assert(t, reflect.DeepEqual(buf, data.buf), buf) + }) + t.Run("ProtobufMsgCodec", func(t *testing.T) { + data := &pbMsgStruct{ + buf: []byte{1, 2, 3, 4}, + } + buf, err := c.Serialize(context.Background(), data) + test.Assert(t, err == nil, err) + test.Assert(t, reflect.DeepEqual(buf, data.buf), buf) + }) + t.Run("ProtobufV2MsgCodec", func(t *testing.T) { + data := &idl.Message{Tiny: 1} + buf, err := c.Serialize(context.Background(), data) + test.Assert(t, err == nil, err) + test.Assert(t, len(buf) != 0, buf) + }) + t.Run("invalid-data-type", func(t *testing.T) { + data := struct{}{} + _, err := c.Serialize(context.Background(), data) + test.Assert(t, errors.Is(err, ErrDataUnmarshalable), err) + }) +} + +func Test_protobufCodec_Deserialize(t *testing.T) { + c := &protobufCodec{} + t.Run("fastpb.Reader:non-zero-payload", func(t *testing.T) { + buf := []byte{byte(protowire.EncodeTag(1, 1)), 1, 2, 3, 4} + data := &fastStruct{} + err := c.Deserialize(context.Background(), data, buf) + test.Assert(t, err == nil, err) + test.Assert(t, reflect.DeepEqual(data.buf, buf[1:]), data.buf) + }) + t.Run("fastpb.Reader:zero-payload", func(t *testing.T) { + buf := []byte{} + data := &fastStruct{} + err := c.Deserialize(context.Background(), data, buf) + test.Assert(t, errors.Is(err, ErrInvalidPayload), err) // no fallback + }) + t.Run("ProtobufV2MsgCodec", func(t *testing.T) { + buf := []byte{1, 2, 3, 4} + data := &pbMsgV2Struct{} + err := c.Deserialize(context.Background(), data, buf) + test.Assert(t, err == nil, err) + test.Assert(t, reflect.DeepEqual(data.buf, buf), data.buf) + }) + t.Run("proto.Message", func(t *testing.T) { + data := &idl.Message{Tiny: 1} + buf, err := c.Serialize(context.Background(), data) + test.Assert(t, err == nil, err) + test.Assert(t, len(buf) != 0, buf) + + rData := &idl.Message{} + err = c.Deserialize(context.Background(), rData, buf) + test.Assert(t, err == nil, err) + test.Assert(t, rData.Tiny == 1, rData.Tiny) + }) + t.Run("ProtobufMsgCodec", func(t *testing.T) { + buf := []byte{1, 2, 3, 4} + data := &pbMsgStruct{} + err := c.Deserialize(context.Background(), data, buf) + test.Assert(t, err == nil, err) + test.Assert(t, reflect.DeepEqual(data.buf, buf), data.buf) + }) + t.Run("invalid-data-type", func(t *testing.T) { + data := struct{}{} + err := c.Deserialize(context.Background(), data, []byte{}) + test.Assert(t, errors.Is(err, ErrInvalidPayload), err) + }) +} diff --git a/pkg/remote/codec/thrift/thrift.go b/pkg/remote/codec/thrift/thrift.go index eb9771e965..1d98e0033a 100644 --- a/pkg/remote/codec/thrift/thrift.go +++ b/pkg/remote/codec/thrift/thrift.go @@ -45,6 +45,9 @@ const ( ) var ( + _ remote.PayloadCodec = (*thriftCodec)(nil) + _ remote.StructCodec = (*thriftCodec)(nil) + defaultCodec = NewThriftCodec().(*thriftCodec) errEncodeMismatchMsgType = remote.NewTransErrorWithMsg(remote.InvalidProtocol, @@ -64,7 +67,7 @@ func IsThriftCodec(c remote.PayloadCodec) bool { return ok } -// NewThriftFrugalCodec creates the thrift binary codec powered by frugal. +// NewThriftCodecWithConfig creates the thrift binary codec powered by frugal. // Eg: xxxservice.NewServer(handler, server.WithPayloadCodec(thrift.NewThriftCodecWithConfig(thrift.FastWrite | thrift.FastRead))) func NewThriftCodecWithConfig(c CodecType) remote.PayloadCodec { return &thriftCodec{c} diff --git a/pkg/remote/codec/thrift/thrift_struct_codec.go b/pkg/remote/codec/thrift/thrift_struct_codec.go new file mode 100644 index 0000000000..d91dd8f3a9 --- /dev/null +++ b/pkg/remote/codec/thrift/thrift_struct_codec.go @@ -0,0 +1,41 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "context" + + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec/perrors" +) + +// Serialize implements the remote.StructCodec interface. +// It's only for encoding thrift struct to bytes +func (c thriftCodec) Serialize(ctx context.Context, data interface{}) ([]byte, error) { + return c.marshalThriftData(ctx, data) +} + +// Deserialize implements the remote.StructCodec interface. +// It's only for decoding bytes to thrift struct (`method`, `msgType` and `seqID` are not required here) +func (c thriftCodec) Deserialize(ctx context.Context, data interface{}, payload []byte) (err error) { + tProt := NewBinaryProtocol(remote.NewReaderBuffer(payload)) + if err = c.unmarshalThriftData(ctx, tProt, "", data, len(payload)); err != nil { + return perrors.NewProtocolError(err) + } + tProt.Recycle() + return nil +} diff --git a/pkg/remote/codec/thrift/thrift_test.go b/pkg/remote/codec/thrift/thrift_test.go index 3d4029b881..eeb440ae5a 100644 --- a/pkg/remote/codec/thrift/thrift_test.go +++ b/pkg/remote/codec/thrift/thrift_test.go @@ -28,7 +28,7 @@ import ( mt "github.com/cloudwego/kitex/internal/mocks/thrift/fast" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote" - netpolltrans "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" + "github.com/cloudwego/kitex/pkg/remote/trans/netpoll/bytebuf" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/transport" @@ -51,7 +51,7 @@ var ( { Name: "NetpollBuffer", NewBuffer: func() remote.ByteBuffer { - return netpolltrans.NewReaderWriterByteBuffer(netpoll.NewLinkBuffer(1024)) + return bytebuf.NewReaderWriterByteBuffer(netpoll.NewLinkBuffer(1024)) }, }, } diff --git a/pkg/remote/codec/ttheader/stream_codec.go b/pkg/remote/codec/ttheader/stream_codec.go new file mode 100644 index 0000000000..371e0a726d --- /dev/null +++ b/pkg/remote/codec/ttheader/stream_codec.go @@ -0,0 +1,113 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheader + +import ( + "context" + "encoding/binary" + "errors" + + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/rpcinfo" +) + +var ( + _ remote.Codec = (*streamCodec)(nil) + + ErrInvalidPayload = errors.New("ttheader streaming: invalid payload") + + ttheaderCodec = codec.TTHeaderCodec() +) + +// streamCodec implements the ttheader codec +type streamCodec struct{} + +// NewStreamCodec ttheader codec construction +func NewStreamCodec() remote.Codec { + return &streamCodec{} +} + +// Encode implements the remote.Codec interface +func (c *streamCodec) Encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (err error) { + var payload []byte + defer func() { + // record send size, even when err != nil (0 is recorded to the lastSendSize) + if rpcStats := rpcinfo.AsMutableRPCStats(message.RPCInfo().Stats()); rpcStats != nil { + rpcStats.IncrSendSize(uint64(len(payload))) + } + }() + + if codec.MessageFrameType(message) == codec.FrameTypeData { + if payload, err = c.marshalData(ctx, message); err != nil { + return err + } + } + + framedSizeBuf, err := ttheaderCodec.EncodeHeader(ctx, message, out) + if err != nil { + return err + } + headerSize := out.MallocLen() - codec.Size32 + binary.BigEndian.PutUint32(framedSizeBuf, uint32(headerSize+len(payload))) + return codec.WriteAll(out.WriteBinary, payload) +} + +func (c *streamCodec) marshalData(ctx context.Context, message remote.Message) ([]byte, error) { + if message.Data() == nil { // for non-data frames: MetaFrame, HeaderFrame, TrailerFrame(w/o err) + return nil, nil + } + structCodec, err := remote.GetStructCodec(message) + if err != nil { + return nil, ErrInvalidPayload + } + return structCodec.Serialize(ctx, message.Data()) +} + +// Decode implements the remote.Codec interface +func (c *streamCodec) Decode(ctx context.Context, message remote.Message, in remote.ByteBuffer) (err error) { + var payload []byte + defer func() { + if rpcStats := rpcinfo.AsMutableRPCStats(message.RPCInfo().Stats()); rpcStats != nil { + // record recv size, even when err != nil (0 is recorded to the lastRecvSize) + rpcStats.IncrRecvSize(uint64(len(payload))) + } + }() + if err = ttheaderCodec.DecodeHeader(ctx, message, in); err != nil { + return err + } + if codec.MessageFrameType(message) != codec.FrameTypeData { + return nil + } + if message.Data() == nil { + return in.Skip(message.PayloadLen()) // necessary for discarding dirty frames + } + // only data frame need to decode payload into message.Data() + if payload, err = in.Next(message.PayloadLen()); err != nil { + return err + } + structCodec, err := remote.GetStructCodec(message) + if err != nil { + return ErrInvalidPayload + } + return structCodec.Deserialize(ctx, message.Data(), payload) +} + +// Name implements the remote.Codec interface +func (c *streamCodec) Name() string { + return "ttheader streaming" +} diff --git a/pkg/remote/codec/ttheader/stream_codec_test.go b/pkg/remote/codec/ttheader/stream_codec_test.go new file mode 100644 index 0000000000..15e53f65d6 --- /dev/null +++ b/pkg/remote/codec/ttheader/stream_codec_test.go @@ -0,0 +1,150 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheader_test + +// another package to avoid cycle import + +import ( + "context" + "net" + "testing" + + "github.com/cloudwego/kitex/internal/test" + ftest "github.com/cloudwego/kitex/pkg/protocol/bthrift/test/kitex_gen/test" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/remote/codec/thrift" + "github.com/cloudwego/kitex/pkg/remote/codec/ttheader" + "github.com/cloudwego/kitex/pkg/remote/transmeta" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" +) + +func newRPCInfo() rpcinfo.RPCInfo { + fromAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:9000") + toAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:9001") + from := rpcinfo.NewEndpointInfo("from-svc", "from-method", fromAddr, nil) + to := rpcinfo.NewEndpointInfo("to-svc", "to-method", toAddr, nil) + ink := rpcinfo.NewInvocation("idl-service-name", "to-method") + config := rpcinfo.NewRPCConfig() + stats := rpcinfo.NewRPCStats() + return rpcinfo.NewRPCInfo(from, to, ink, config, stats) +} + +func init() { + remote.PutPayloadCode(serviceinfo.Thrift, thrift.NewThriftCodec()) +} + +type thriftStruct struct { + N int +} + +func Test_streamCodec_Encode(t *testing.T) { + c := ttheader.NewStreamCodec() + t.Run("encode-data-failed", func(t *testing.T) { + msg := remote.NewMessage(&thriftStruct{}, nil, newRPCInfo(), remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeData + out := remote.NewWriterBuffer(1024) + err := c.Encode(context.Background(), msg, out) + test.Assert(t, err != nil, err) + }) + t.Run("encode-header-failed", func(t *testing.T) { + msg := remote.NewMessage(&thriftStruct{}, nil, newRPCInfo(), remote.Stream, remote.Server) + msg.TransInfo().TransStrInfo()["xxx"] = string(make([]byte, 65536)) // longer than TTHeader size limit + out := remote.NewWriterBuffer(1024) + err := c.Encode(context.Background(), msg, out) + test.Assert(t, err != nil, err) + test.Assert(t, msg.RPCInfo().Stats().LastSendSize() == 0) + }) +} + +func Test_streamCodec_Decode(t *testing.T) { + c := ttheader.NewStreamCodec() + t.Run("decode-header-failed", func(t *testing.T) { + msg := remote.NewMessage(&thriftStruct{}, nil, newRPCInfo(), remote.Stream, remote.Server) + in := remote.NewReaderBuffer(nil) + err := c.Decode(context.Background(), msg, in) + test.Assert(t, err != nil, err) + }) + t.Run("decode-non-data-frame", func(t *testing.T) { + wMsg := remote.NewMessage(nil, nil, newRPCInfo(), remote.Stream, remote.Server) + wMsg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + out := remote.NewWriterBuffer(1024) + err := c.Encode(context.Background(), wMsg, out) + test.Assert(t, err == nil, err) + test.Assert(t, wMsg.RPCInfo().Stats().LastSendSize() == 0) + + rMsg := remote.NewMessage(&thriftStruct{}, nil, newRPCInfo(), remote.Stream, remote.Server) + buf, _ := out.Bytes() + in := remote.NewReaderBuffer(buf) + err = c.Decode(context.Background(), rMsg, in) + test.Assert(t, err == nil, err) + test.Assert(t, rMsg.RPCInfo().Stats().LastRecvSize() == 0) + test.Assert(t, rMsg.TransInfo().TransIntInfo()[transmeta.FrameType] == codec.FrameTypeHeader) + }) + + t.Run("decode-data-frame", func(t *testing.T) { + wMsg := remote.NewMessage(&ftest.Local{L: 1}, nil, newRPCInfo(), remote.Stream, remote.Server) + wMsg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeData + out := remote.NewWriterBuffer(1024) + err := c.Encode(context.Background(), wMsg, out) + test.Assert(t, err == nil, err) + test.Assert(t, wMsg.RPCInfo().Stats().LastSendSize() > 0) + + rData := &ftest.Local{} + rMsg := remote.NewMessage(rData, nil, newRPCInfo(), remote.Stream, remote.Server) + buf, _ := out.Bytes() + in := remote.NewReaderBuffer(buf) + err = c.Decode(context.Background(), rMsg, in) + test.Assert(t, err == nil, err) + test.Assert(t, rMsg.RPCInfo().Stats().LastRecvSize() > 0) + test.Assert(t, rData.L == 1, rData) + test.Assert(t, rMsg.TransInfo().TransIntInfo()[transmeta.FrameType] == codec.FrameTypeData) + }) + + t.Run("decode-data-frame-skip-payload", func(t *testing.T) { + wMsg := remote.NewMessage(&ftest.Local{L: 1}, nil, newRPCInfo(), remote.Stream, remote.Server) + wMsg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeData + out := remote.NewWriterBuffer(1024) + err := c.Encode(context.Background(), wMsg, out) + test.Assert(t, err == nil, err) + test.Assert(t, wMsg.RPCInfo().Stats().LastSendSize() > 0) + + rMsg := remote.NewMessage(nil, nil, newRPCInfo(), remote.Stream, remote.Server) + buf, _ := out.Bytes() + in := remote.NewReaderBuffer(buf) + err = c.Decode(context.Background(), rMsg, in) + test.Assert(t, err == nil, err) + test.Assert(t, rMsg.RPCInfo().Stats().LastRecvSize() == 0) + test.Assert(t, rMsg.TransInfo().TransIntInfo()[transmeta.FrameType] == codec.FrameTypeData) + }) + + t.Run("decode-data-frame-payload-short-read", func(t *testing.T) { + wMsg := remote.NewMessage(&ftest.Local{L: 1}, nil, newRPCInfo(), remote.Stream, remote.Server) + wMsg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeData + out := remote.NewWriterBuffer(1024) + err := c.Encode(context.Background(), wMsg, out) + test.Assert(t, err == nil, err) + test.Assert(t, wMsg.RPCInfo().Stats().LastSendSize() > 0) + + rMsg := remote.NewMessage(nil, nil, newRPCInfo(), remote.Stream, remote.Server) + buf, _ := out.Bytes() + in := remote.NewReaderBuffer(buf[0 : len(buf)-1]) + err = c.Decode(context.Background(), rMsg, in) + test.Assert(t, err != nil, err) + }) +} diff --git a/pkg/remote/codec/util.go b/pkg/remote/codec/util.go index 78b13347f2..e137a5432a 100644 --- a/pkg/remote/codec/util.go +++ b/pkg/remote/codec/util.go @@ -17,11 +17,14 @@ package codec import ( + "encoding/binary" "errors" "fmt" "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/transmeta" "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" ) const ( @@ -115,3 +118,42 @@ func NewDataIfNeeded(method string, message remote.Message) error { } return remote.NewTransErrorWithMsg(remote.InternalError, "message data for codec is nil") } + +// ProtocolIDToPayloadCodec converts ProtocolID to PayloadCodec +func ProtocolIDToPayloadCodec(p ProtocolID) serviceinfo.PayloadCodec { + switch p { + case ProtocolIDThriftBinary, ProtocolIDThriftCompact, ProtocolIDThriftCompactV2: + return serviceinfo.Thrift + case ProtocolIDKitexProtobuf: + return serviceinfo.Protobuf + default: + return serviceinfo.NotSpecified + } +} + +// PayloadCodecToProtocolID converts PayloadCodec to ProtocolID +func PayloadCodecToProtocolID(p serviceinfo.PayloadCodec) ProtocolID { + switch p { + case serviceinfo.Thrift: + return ProtocolIDThriftBinary + case serviceinfo.Protobuf: + return ProtocolIDKitexProtobuf + default: + return ProtocolIDNotSpecified + } +} + +// MessageFrameType returns the frame type of the message; the default value is "trailer" if the key is not found. +func MessageFrameType(message remote.Message) string { + if ft, exists := message.TransInfo().TransIntInfo()[transmeta.FrameType]; exists { + return ft + } + // to be compatible with old ttheader frames containing TApplicationException + return FrameTypeTrailer +} + +// IsTTHeaderStreaming checks whether the bytes match ttheader streaming (magic + flags with streaming bit set) +func IsTTHeaderStreaming(bytes []byte) bool { + return binary.BigEndian.Uint16(bytes[Size32:]) == uint16(TTHeaderMagic>>16) && + binary.BigEndian.Uint16(bytes[Size32+2:])&uint16(HeaderFlagsStreaming) != 0 +} diff --git a/pkg/remote/default_bytebuf.go b/pkg/remote/default_bytebuf.go index fe7b3f3a94..18f80ebf79 100644 --- a/pkg/remote/default_bytebuf.go +++ b/pkg/remote/default_bytebuf.go @@ -60,7 +60,10 @@ type defaultByteBuffer struct { status int } -var _ ByteBuffer = &defaultByteBuffer{} +var ( + _ ByteBuffer = (*defaultByteBuffer)(nil) + _ FrameWrite = (*defaultByteBuffer)(nil) +) func newDefaultByteBuffer() interface{} { return &defaultByteBuffer{} @@ -98,6 +101,18 @@ func newReaderWriterByteBuffer(estimatedLength int) ByteBuffer { return bytebuf } +// WriteHeader implements FrameWrite +func (b *defaultByteBuffer) WriteHeader(buf []byte) (err error) { + _, err = b.WriteBinary(buf) + return +} + +// WriteData implements FrameWrite +func (b *defaultByteBuffer) WriteData(buf []byte) (err error) { + _, err = b.WriteBinary(buf) + return +} + // Next reads n bytes sequentially, returns the original address. func (b *defaultByteBuffer) Next(n int) (buf []byte, err error) { if b.status&BitReadable == 0 { diff --git a/pkg/remote/default_bytebuf_test.go b/pkg/remote/default_bytebuf_test.go index b3fcfb527e..83558bc3c7 100644 --- a/pkg/remote/default_bytebuf_test.go +++ b/pkg/remote/default_bytebuf_test.go @@ -155,3 +155,21 @@ func checkUnreadable(t *testing.T, buf ByteBuffer) { test.Assert(t, err != nil) test.Assert(t, n == -1, n) } + +func Test_defaultByteBuffer_WriteHeader(t *testing.T) { + b := newWriterByteBuffer(16) + fw := b.(FrameWrite) + err := fw.WriteHeader([]byte("hello")) + test.Assert(t, err == nil, err) + got, _ := b.Bytes() + test.Assert(t, string(got) == "hello", string(got)) +} + +func Test_defaultByteBuffer_WriteData(t *testing.T) { + b := newWriterByteBuffer(16) + fw := b.(FrameWrite) + err := fw.WriteData([]byte("hello")) + test.Assert(t, err == nil, err) + got, _ := b.Bytes() + test.Assert(t, string(got) == "hello", string(got)) +} diff --git a/pkg/remote/option.go b/pkg/remote/option.go index afe0f567b9..5dd62da10f 100644 --- a/pkg/remote/option.go +++ b/pkg/remote/option.go @@ -35,7 +35,8 @@ type Option struct { Inbounds []InboundHandler - StreamingMetaHandlers []StreamingMetaHandler + StreamingMetaHandlers []StreamingMetaHandler + TTHeaderFrameMetaHandler FrameMetaHandler } // PrependBoundHandler adds a BoundHandler to the head. @@ -144,4 +145,7 @@ type ClientOption struct { Option EnableConnPoolReporter bool + + TTHeaderStreamingWaitMetaFrame bool + TTHeaderStreamingGRPCCompatible bool // not enabled by default for performance issue } diff --git a/pkg/remote/payload_codec.go b/pkg/remote/payload_codec.go index 00f88fc072..321f92f911 100644 --- a/pkg/remote/payload_codec.go +++ b/pkg/remote/payload_codec.go @@ -32,12 +32,23 @@ type PayloadCodec interface { Name() string } +// StructCodec is used to serialize and deserialize a given struct. +type StructCodec interface { + Serialize(ctx context.Context, data interface{}) ([]byte, error) + Deserialize(ctx context.Context, data interface{}, payload []byte) error +} + // GetPayloadCodec gets desired payload codec from message. func GetPayloadCodec(message Message) (PayloadCodec, error) { if message.PayloadCodec() != nil { return message.PayloadCodec(), nil } ct := message.ProtocolInfo().CodecType + return GetPayloadCodecByCodecType(ct) +} + +// GetPayloadCodecByCodecType gets payload codec by codecType. +func GetPayloadCodecByCodecType(ct serviceinfo.PayloadCodec) (PayloadCodec, error) { pc := payloadCodecs[ct] if pc == nil { return nil, fmt.Errorf("payload codec not found with codecType=%v", ct) @@ -49,3 +60,23 @@ func GetPayloadCodec(message Message) (PayloadCodec, error) { func PutPayloadCode(name serviceinfo.PayloadCodec, v PayloadCodec) { payloadCodecs[name] = v } + +// GetStructCodec gets desired payload struct codec from message. +func GetStructCodec(message Message) (StructCodec, error) { + if codec, err := GetPayloadCodec(message); err == nil { + if structCodec := codec.(StructCodec); structCodec != nil { + return structCodec, nil + } + } + return nil, fmt.Errorf("payload struct codec not found with codecType=%v", message.ProtocolInfo().CodecType) +} + +// GetStructCodecByCodecType gets the struct codec by codecType. +func GetStructCodecByCodecType(ct serviceinfo.PayloadCodec) (StructCodec, error) { + if codec, err := GetPayloadCodecByCodecType(ct); err == nil { + if structCodec := codec.(StructCodec); structCodec != nil { + return structCodec, nil + } + } + return nil, fmt.Errorf("payload struct codec not found with codecType=%v", ct) +} diff --git a/pkg/remote/remotecli/stream.go b/pkg/remote/remotecli/stream.go index 344a2d87b3..acf90de249 100644 --- a/pkg/remote/remotecli/stream.go +++ b/pkg/remote/remotecli/stream.go @@ -19,13 +19,32 @@ package remotecli import ( "context" + "fmt" + "runtime/debug" + "time" + "github.com/bytedance/gopkg/util/gopool" + + "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote" - "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/streaming" + "github.com/cloudwego/kitex/transport" ) +type StreamCleaner struct { + Async bool + Timeout time.Duration + Clean func(st streaming.Stream) error +} + +var registeredCleaner = map[transport.Protocol]*StreamCleaner{} + +// RegisterCleaner registers a StreamCleaner associated with the given protocol +func RegisterCleaner(protocol transport.Protocol, cleaner *StreamCleaner) { + registeredCleaner[protocol] = cleaner +} + // NewStream create a client side stream func NewStream(ctx context.Context, ri rpcinfo.RPCInfo, handler remote.ClientTransHandler, opt *remote.ClientOption) (streaming.Stream, *StreamConnManager, error) { cm := NewConnWrapper(opt.ConnPool) @@ -40,7 +59,18 @@ func NewStream(ctx context.Context, ri rpcinfo.RPCInfo, handler remote.ClientTra if err != nil { return nil, nil, err } - return nphttp2.NewStream(ctx, opt.SvcInfo, rawConn, handler), &StreamConnManager{cm}, nil + + allocator, ok := handler.(remote.ClientStreamAllocator) + if !ok { + err = fmt.Errorf("remote.ClientStreamAllocator is not implemented by %T", handler) + return nil, nil, remote.NewTransError(remote.InternalError, err) + } + + st, err := allocator.NewStream(ctx, opt.SvcInfo, rawConn, handler) + if err != nil { + return nil, nil, err + } + return st, &StreamConnManager{cm}, nil } // NewStreamConnManager returns a new StreamConnManager @@ -54,6 +84,57 @@ type StreamConnManager struct { } // ReleaseConn releases the raw connection of the stream -func (scm *StreamConnManager) ReleaseConn(err error, ri rpcinfo.RPCInfo) { - scm.ConnReleaser.ReleaseConn(err, ri) +func (scm *StreamConnManager) ReleaseConn(st streaming.Stream, err error, ri rpcinfo.RPCInfo) { + if err != nil { + // for non-nil error, just pass it to the ConnReleaser + scm.ConnReleaser.ReleaseConn(err, ri) + return + } + + cleaner, exists := registeredCleaner[ri.Config().TransportProtocol()] + if !exists { + scm.ConnReleaser.ReleaseConn(nil, ri) + return + } + + if cleaner.Async { + scm.AsyncCleanAndRelease(cleaner, st, ri) + } else { + cleanErr := cleaner.Clean(st) + scm.ConnReleaser.ReleaseConn(cleanErr, ri) + } +} + +// AsyncCleanAndRelease releases the raw connection of the stream asynchronously if an async cleaner is registered +func (scm *StreamConnManager) AsyncCleanAndRelease(cleaner *StreamCleaner, st streaming.Stream, ri rpcinfo.RPCInfo) { + gopool.Go(func() { + var finalErr error + defer func() { + if panicInfo := recover(); panicInfo != nil { + finalErr = fmt.Errorf("panic=%v, stack=%s", panicInfo, debug.Stack()) + } + if finalErr != nil { + klog.Debugf("cleaner failed: %v", finalErr) + } + scm.ConnReleaser.ReleaseConn(finalErr, ri) + }() + finished := make(chan error, 1) + t := time.NewTimer(cleaner.Timeout) + gopool.Go(func() { + var cleanErr error + defer func() { + if panicInfo := recover(); panicInfo != nil { + cleanErr = fmt.Errorf("panic=%v, stack=%s", panicInfo, debug.Stack()) + } + finished <- cleanErr + }() + cleanErr = cleaner.Clean(st) + }) + select { + case <-t.C: + finalErr = context.DeadlineExceeded + case cleanErr := <-finished: + finalErr = cleanErr + } + }) } diff --git a/pkg/remote/remotecli/stream_test.go b/pkg/remote/remotecli/stream_test.go index 196a14dc89..efd554b0b6 100644 --- a/pkg/remote/remotecli/stream_test.go +++ b/pkg/remote/remotecli/stream_test.go @@ -18,7 +18,10 @@ package remotecli import ( "context" + "errors" + "net" "testing" + "time" "github.com/golang/mock/gomock" @@ -26,9 +29,13 @@ import ( mock_remote "github.com/cloudwego/kitex/internal/mocks/remote" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "github.com/cloudwego/kitex/pkg/rpcinfo" + kitex "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streaming" "github.com/cloudwego/kitex/pkg/transmeta" "github.com/cloudwego/kitex/pkg/utils" + "github.com/cloudwego/kitex/transport" ) func TestNewStream(t *testing.T) { @@ -40,6 +47,9 @@ func TestNewStream(t *testing.T) { ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) hdlr := &mocks.MockCliTransHandler{} + hdlr.NewStreamFunc = func(ctx context.Context, svcInfo *kitex.ServiceInfo, conn net.Conn, handler remote.TransReadWriter) (streaming.Stream, error) { + return nphttp2.NewStream(ctx, svcInfo, conn, handler), nil + } opt := newMockOption(ctrl, addr) opt.Option = remote.Option{ @@ -61,5 +71,164 @@ func TestStreamConnManagerReleaseConn(t *testing.T) { scr := NewStreamConnManager(cr) test.Assert(t, scr != nil) test.Assert(t, scr.ConnReleaser == cr) - scr.ReleaseConn(nil, nil) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), nil) + scr.ReleaseConn(nil, nil, ri) +} + +type mockConnReleaser struct { + releaseConn func(err error, ri rpcinfo.RPCInfo) +} + +func (m *mockConnReleaser) ReleaseConn(err error, ri rpcinfo.RPCInfo) { + if m.releaseConn != nil { + m.releaseConn(err, ri) + } +} + +func TestStreamConnManager_AsyncCleanAndRelease(t *testing.T) { + t.Run("finish-without-error", func(t *testing.T) { + cleaner := &StreamCleaner{ + Async: true, + Timeout: time.Second, + Clean: func(st streaming.Stream) error { + return nil + }, + } + result := make(chan error, 1) + scm := NewStreamConnManager(&mockConnReleaser{ + releaseConn: func(err error, ri rpcinfo.RPCInfo) { + result <- err + }, + }) + scm.AsyncCleanAndRelease(cleaner, nil, nil) + got := <-result + test.Assert(t, got == nil) + }) + t.Run("finish-with-error", func(t *testing.T) { + cleaner := &StreamCleaner{ + Async: true, + Timeout: time.Second, + Clean: func(st streaming.Stream) error { + return context.DeadlineExceeded + }, + } + result := make(chan error, 1) + scm := NewStreamConnManager(&mockConnReleaser{ + releaseConn: func(err error, ri rpcinfo.RPCInfo) { + result <- err + }, + }) + scm.AsyncCleanAndRelease(cleaner, nil, nil) + got := <-result + test.Assert(t, got != nil) + }) + t.Run("cleaner-timeout", func(t *testing.T) { + cleaner := &StreamCleaner{ + Async: true, + Timeout: time.Millisecond * 10, + Clean: func(st streaming.Stream) error { + time.Sleep(time.Millisecond * 20) + return nil + }, + } + result := make(chan error, 1) + scm := NewStreamConnManager(&mockConnReleaser{ + releaseConn: func(err error, ri rpcinfo.RPCInfo) { + result <- err + }, + }) + scm.AsyncCleanAndRelease(cleaner, nil, nil) + got := <-result + test.Assert(t, got != nil) + }) + t.Run("cleaner-panic", func(t *testing.T) { + cleaner := &StreamCleaner{ + Async: true, + Timeout: time.Second, + Clean: func(st streaming.Stream) error { + panic("cleaner panic") + }, + } + result := make(chan error, 1) + scm := NewStreamConnManager(&mockConnReleaser{ + releaseConn: func(err error, ri rpcinfo.RPCInfo) { + result <- err + }, + }) + scm.AsyncCleanAndRelease(cleaner, nil, nil) + got := <-result + test.Assert(t, got != nil) + }) +} + +func TestStreamConnManager_ReleaseConn(t *testing.T) { + t.Run("non-nil-error", func(t *testing.T) { + err := errors.New("error") + result := make(chan error, 1) + scm := NewStreamConnManager(&mockConnReleaser{ + releaseConn: func(err error, ri rpcinfo.RPCInfo) { + result <- err + }, + }) + scm.ReleaseConn(nil, err, nil) + got := <-result + test.Assert(t, errors.Is(got, err)) + }) + t.Run("nil-error-without-registered-cleaner", func(t *testing.T) { + result := make(chan error, 1) + scm := NewStreamConnManager(&mockConnReleaser{ + releaseConn: func(err error, ri rpcinfo.RPCInfo) { + result <- err + }, + }) + cfg := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.Framed) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, cfg, nil) + scm.ReleaseConn(nil, nil, ri) + got := <-result + test.Assert(t, got == nil) + }) + t.Run("nil-error-with-registered-sync-cleaner", func(t *testing.T) { + err := errors.New("error") + result := make(chan error, 1) + scm := NewStreamConnManager(&mockConnReleaser{ + releaseConn: func(err error, ri rpcinfo.RPCInfo) { + result <- err + }, + }) + cfg := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.Framed) + RegisterCleaner(transport.Framed, &StreamCleaner{ + Async: false, + Clean: func(st streaming.Stream) error { + return err + }, + }) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, cfg, nil) + scm.ReleaseConn(nil, nil, ri) + got := <-result + test.Assert(t, errors.Is(got, err)) + }) + t.Run("nil-error-with-registered-async-cleaner", func(t *testing.T) { + err := errors.New("error") + result := make(chan error, 1) + scm := NewStreamConnManager(&mockConnReleaser{ + releaseConn: func(err error, ri rpcinfo.RPCInfo) { + result <- err + }, + }) + cfg := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.Framed) + RegisterCleaner(transport.Framed, &StreamCleaner{ + Async: true, + Timeout: time.Second, + Clean: func(st streaming.Stream) error { + return err + }, + }) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, cfg, nil) + scm.ReleaseConn(nil, nil, ri) + got := <-result + test.Assert(t, errors.Is(got, err)) + }) } diff --git a/pkg/remote/trans/gonet/client_handler.go b/pkg/remote/trans/gonet/client_handler.go index b9133d41cd..fb63c1c845 100644 --- a/pkg/remote/trans/gonet/client_handler.go +++ b/pkg/remote/trans/gonet/client_handler.go @@ -19,6 +19,12 @@ package gonet import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/trans" + "github.com/cloudwego/kitex/pkg/remote/trans/ttheaderstreaming" +) + +var ( + _ remote.ClientTransHandlerFactory = (*cliTransHandlerFactory)(nil) + _ remote.ClientStreamTransHandlerFactory = (*cliTransHandlerFactory)(nil) ) type cliTransHandlerFactory struct{} @@ -36,3 +42,7 @@ func (f *cliTransHandlerFactory) NewTransHandler(opt *remote.ClientOption) (remo func newCliTransHandler(opt *remote.ClientOption) (remote.ClientTransHandler, error) { return trans.NewDefaultCliTransHandler(opt, NewGonetExtension()) } + +func (f *cliTransHandlerFactory) NewStreamTransHandler(opt *remote.ClientOption) (remote.ClientTransHandler, error) { + return ttheaderstreaming.NewCliTransHandler(opt, NewGonetExtension()) +} diff --git a/pkg/remote/trans/gonet/server_handler.go b/pkg/remote/trans/gonet/server_handler.go index a3cb93ce95..472b2e5b6f 100644 --- a/pkg/remote/trans/gonet/server_handler.go +++ b/pkg/remote/trans/gonet/server_handler.go @@ -19,6 +19,7 @@ package gonet import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/trans" + "github.com/cloudwego/kitex/pkg/remote/trans/ttheaderstreaming" ) type svrTransHandlerFactory struct{} @@ -36,3 +37,13 @@ func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remo func newSvrTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) { return trans.NewDefaultSvrTransHandler(opt, NewGonetExtension()) } + +type ttheaderStreamingSvrTransHandlerFactory struct{} + +func (s *ttheaderStreamingSvrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) { + return ttheaderstreaming.NewSvrTransHandler(opt, NewGonetExtension()) +} + +func NewTTHeaderStreamingSvrTransHandlerFactory() remote.ServerTransHandlerFactory { + return &ttheaderStreamingSvrTransHandlerFactory{} +} diff --git a/pkg/remote/trans/netpoll/bytebuf.go b/pkg/remote/trans/netpoll/bytebuf.go index 736060da97..6b31416e28 100644 --- a/pkg/remote/trans/netpoll/bytebuf.go +++ b/pkg/remote/trans/netpoll/bytebuf.go @@ -17,259 +17,26 @@ package netpoll import ( - "errors" - "io" - "sync" - "github.com/cloudwego/netpoll" "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/trans/netpoll/bytebuf" ) -var bytebufPool sync.Pool - -func init() { - bytebufPool.New = newNetpollByteBuffer -} +// The original implementation is moved to the sub package bytebuf. +// This file now only contains existing exported symbols to keep the compatibility. // NewReaderByteBuffer creates a new remote.ByteBuffer using the given netpoll.ZeroCopyReader. func NewReaderByteBuffer(r netpoll.Reader) remote.ByteBuffer { - bytebuf := bytebufPool.Get().(*netpollByteBuffer) - bytebuf.reader = r - // TODO(wangtieju): fix me when netpoll support netpoll.Reader - // and LinkBuffer not support io.Reader, type assertion would fail when r is from NewBuffer - if ir, ok := r.(io.Reader); ok { - bytebuf.ioReader = ir - } - bytebuf.status = remote.BitReadable - bytebuf.readSize = 0 - return bytebuf + return bytebuf.NewReaderByteBuffer(r) } // NewWriterByteBuffer creates a new remote.ByteBuffer using the given netpoll.ZeroCopyWriter. func NewWriterByteBuffer(w netpoll.Writer) remote.ByteBuffer { - bytebuf := bytebufPool.Get().(*netpollByteBuffer) - bytebuf.writer = w - // TODO(wangtieju): fix me when netpoll support netpoll.Writer - // and LinkBuffer not support io.Reader, type assertion would fail when w is from NewBuffer - if iw, ok := w.(io.Writer); ok { - bytebuf.ioWriter = iw - } - bytebuf.status = remote.BitWritable - return bytebuf + return bytebuf.NewWriterByteBuffer(w) } // NewReaderWriterByteBuffer creates a new remote.ByteBuffer using the given netpoll.ZeroCopyReadWriter. func NewReaderWriterByteBuffer(rw netpoll.ReadWriter) remote.ByteBuffer { - bytebuf := bytebufPool.Get().(*netpollByteBuffer) - bytebuf.writer = rw - bytebuf.reader = rw - // TODO(wangtieju): fix me when netpoll support netpoll.ReadWriter - // and LinkBuffer not support io.ReadWriter, type assertion would fail when rw is from NewBuffer - if irw, ok := rw.(io.ReadWriter); ok { - bytebuf.ioReader = irw - bytebuf.ioWriter = irw - } - bytebuf.status = remote.BitWritable | remote.BitReadable - return bytebuf -} - -func newNetpollByteBuffer() interface{} { - return &netpollByteBuffer{} -} - -type netpollByteBuffer struct { - writer netpoll.Writer - reader netpoll.Reader - ioReader io.Reader - ioWriter io.Writer - status int - readSize int -} - -var _ remote.ByteBuffer = &netpollByteBuffer{} - -// Next reads n bytes sequentially, returns the original address. -func (b *netpollByteBuffer) Next(n int) (p []byte, err error) { - if b.status&remote.BitReadable == 0 { - return nil, errors.New("unreadable buffer, cannot support Next") - } - if p, err = b.reader.Next(n); err == nil { - b.readSize += n - } - return -} - -// Peek returns the next n bytes without advancing the reader. -func (b *netpollByteBuffer) Peek(n int) (buf []byte, err error) { - if b.status&remote.BitReadable == 0 { - return nil, errors.New("unreadable buffer, cannot support Peek") - } - return b.reader.Peek(n) -} - -// Skip is used to skip n bytes, it's much faster than Next. -// Skip will not cause release. -func (b *netpollByteBuffer) Skip(n int) (err error) { - if b.status&remote.BitReadable == 0 { - return errors.New("unreadable buffer, cannot support Skip") - } - return b.reader.Skip(n) -} - -// ReadableLen returns the total length of readable buffer. -func (b *netpollByteBuffer) ReadableLen() (n int) { - if b.status&remote.BitReadable == 0 { - return -1 - } - return b.reader.Len() -} - -// Read implement io.Reader -func (b *netpollByteBuffer) Read(p []byte) (n int, err error) { - if b.status&remote.BitReadable == 0 { - return -1, errors.New("unreadable buffer, cannot support Read") - } - if b.ioReader != nil { - return b.ioReader.Read(p) - } - return -1, errors.New("ioReader is nil") -} - -// ReadString is a more efficient way to read string than Next. -func (b *netpollByteBuffer) ReadString(n int) (s string, err error) { - if b.status&remote.BitReadable == 0 { - return "", errors.New("unreadable buffer, cannot support ReadString") - } - if s, err = b.reader.ReadString(n); err == nil { - b.readSize += n - } - return -} - -// ReadBinary like ReadString. -// Returns a copy of original buffer. -func (b *netpollByteBuffer) ReadBinary(n int) (p []byte, err error) { - if b.status&remote.BitReadable == 0 { - return p, errors.New("unreadable buffer, cannot support ReadBinary") - } - if p, err = b.reader.ReadBinary(n); err == nil { - b.readSize += n - } - return -} - -// Malloc n bytes sequentially in the writer buffer. -func (b *netpollByteBuffer) Malloc(n int) (buf []byte, err error) { - if b.status&remote.BitWritable == 0 { - return nil, errors.New("unwritable buffer, cannot support Malloc") - } - return b.writer.Malloc(n) -} - -// MallocAck n bytes in the writer buffer. -func (b *netpollByteBuffer) MallocAck(n int) (err error) { - if b.status&remote.BitWritable == 0 { - return errors.New("unwritable buffer, cannot support MallocAck") - } - return b.writer.MallocAck(n) -} - -// MallocLen returns the total length of the buffer malloced. -func (b *netpollByteBuffer) MallocLen() (length int) { - if b.status&remote.BitWritable == 0 { - return -1 - } - return b.writer.MallocLen() -} - -// Write implement io.Writer -func (b *netpollByteBuffer) Write(p []byte) (n int, err error) { - if b.status&remote.BitWritable == 0 { - return -1, errors.New("unwritable buffer, cannot support Write") - } - if b.ioWriter != nil { - return b.ioWriter.Write(p) - } - return -1, errors.New("ioWriter is nil") -} - -// WriteString is a more efficient way to write string, using the unsafe method to convert the string to []byte. -func (b *netpollByteBuffer) WriteString(s string) (n int, err error) { - if b.status&remote.BitWritable == 0 { - return -1, errors.New("unwritable buffer, cannot support WriteString") - } - return b.writer.WriteString(s) -} - -// WriteBinary writes the []byte directly. Callers must guarantee that the []byte doesn't change. -func (b *netpollByteBuffer) WriteBinary(p []byte) (n int, err error) { - if b.status&remote.BitWritable == 0 { - return -1, errors.New("unwritable buffer, cannot support WriteBinary") - } - return b.writer.WriteBinary(p) -} - -// WriteDirect is a way to write []byte without copying, and splits the original buffer. -func (b *netpollByteBuffer) WriteDirect(p []byte, remainCap int) error { - if b.status&remote.BitWritable == 0 { - return errors.New("unwritable buffer, cannot support WriteBinary") - } - return b.writer.WriteDirect(p, remainCap) -} - -// ReadLen returns the size already read. -func (b *netpollByteBuffer) ReadLen() (n int) { - return b.readSize -} - -// Flush writes any malloc data to the underlying io.Writer. -// The malloced buffer must be set correctly. -func (b *netpollByteBuffer) Flush() (err error) { - if b.status&remote.BitWritable == 0 { - return errors.New("unwritable buffer, cannot support Flush") - } - return b.writer.Flush() -} - -// NewBuffer returns a new writable remote.ByteBuffer. -func (b *netpollByteBuffer) NewBuffer() remote.ByteBuffer { - return NewWriterByteBuffer(netpoll.NewLinkBuffer()) -} - -// AppendBuffer appends buf to the original buffer. -func (b *netpollByteBuffer) AppendBuffer(buf remote.ByteBuffer) (err error) { - subBuf := buf.(*netpollByteBuffer) - err = b.writer.Append(subBuf.writer) - buf.Release(nil) - return -} - -// Bytes are not supported in netpoll bytebuf. -func (b *netpollByteBuffer) Bytes() (buf []byte, err error) { - if b.reader != nil { - return b.reader.Peek(b.reader.Len()) - } - return nil, errors.New("method Bytes() not support in netpoll bytebuf") -} - -// Release will free the buffer already read. -// After release, buffer read by Next/Skip/Peek is invalid. -func (b *netpollByteBuffer) Release(e error) (err error) { - if b.reader != nil { - // 重复执行Release nil panic - err = b.reader.Release() - } - b.zero() - bytebufPool.Put(b) - return -} - -func (b *netpollByteBuffer) zero() { - b.writer = nil - b.reader = nil - b.ioReader = nil - b.ioWriter = nil - b.status = 0 - b.readSize = 0 + return bytebuf.NewReaderWriterByteBuffer(rw) } diff --git a/pkg/remote/trans/netpoll/bytebuf/bytebuf.go b/pkg/remote/trans/netpoll/bytebuf/bytebuf.go new file mode 100644 index 0000000000..7998621a21 --- /dev/null +++ b/pkg/remote/trans/netpoll/bytebuf/bytebuf.go @@ -0,0 +1,275 @@ +/* + * Copyright 2021 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package bytebuf + +import ( + "errors" + "io" + "sync" + + "github.com/cloudwego/netpoll" + + "github.com/cloudwego/kitex/pkg/remote" +) + +var bytebufPool sync.Pool + +func init() { + bytebufPool.New = newNetpollByteBuffer +} + +// NewReaderByteBuffer creates a new remote.ByteBuffer using the given netpoll.ZeroCopyReader. +func NewReaderByteBuffer(r netpoll.Reader) remote.ByteBuffer { + bytebuf := bytebufPool.Get().(*netpollByteBuffer) + bytebuf.reader = r + // TODO(wangtieju): fix me when netpoll support netpoll.Reader + // and LinkBuffer not support io.Reader, type assertion would fail when r is from NewBuffer + if ir, ok := r.(io.Reader); ok { + bytebuf.ioReader = ir + } + bytebuf.status = remote.BitReadable + bytebuf.readSize = 0 + return bytebuf +} + +// NewWriterByteBuffer creates a new remote.ByteBuffer using the given netpoll.ZeroCopyWriter. +func NewWriterByteBuffer(w netpoll.Writer) remote.ByteBuffer { + bytebuf := bytebufPool.Get().(*netpollByteBuffer) + bytebuf.writer = w + // TODO(wangtieju): fix me when netpoll support netpoll.Writer + // and LinkBuffer not support io.Reader, type assertion would fail when w is from NewBuffer + if iw, ok := w.(io.Writer); ok { + bytebuf.ioWriter = iw + } + bytebuf.status = remote.BitWritable + return bytebuf +} + +// NewReaderWriterByteBuffer creates a new remote.ByteBuffer using the given netpoll.ZeroCopyReadWriter. +func NewReaderWriterByteBuffer(rw netpoll.ReadWriter) remote.ByteBuffer { + bytebuf := bytebufPool.Get().(*netpollByteBuffer) + bytebuf.writer = rw + bytebuf.reader = rw + // TODO(wangtieju): fix me when netpoll support netpoll.ReadWriter + // and LinkBuffer not support io.ReadWriter, type assertion would fail when rw is from NewBuffer + if irw, ok := rw.(io.ReadWriter); ok { + bytebuf.ioReader = irw + bytebuf.ioWriter = irw + } + bytebuf.status = remote.BitWritable | remote.BitReadable + return bytebuf +} + +func newNetpollByteBuffer() interface{} { + return &netpollByteBuffer{} +} + +type netpollByteBuffer struct { + writer netpoll.Writer + reader netpoll.Reader + ioReader io.Reader + ioWriter io.Writer + status int + readSize int +} + +var _ remote.ByteBuffer = &netpollByteBuffer{} + +// Next reads n bytes sequentially, returns the original address. +func (b *netpollByteBuffer) Next(n int) (p []byte, err error) { + if b.status&remote.BitReadable == 0 { + return nil, errors.New("unreadable buffer, cannot support Next") + } + if p, err = b.reader.Next(n); err == nil { + b.readSize += n + } + return +} + +// Peek returns the next n bytes without advancing the reader. +func (b *netpollByteBuffer) Peek(n int) (buf []byte, err error) { + if b.status&remote.BitReadable == 0 { + return nil, errors.New("unreadable buffer, cannot support Peek") + } + return b.reader.Peek(n) +} + +// Skip is used to skip n bytes, it's much faster than Next. +// Skip will not cause release. +func (b *netpollByteBuffer) Skip(n int) (err error) { + if b.status&remote.BitReadable == 0 { + return errors.New("unreadable buffer, cannot support Skip") + } + return b.reader.Skip(n) +} + +// ReadableLen returns the total length of readable buffer. +func (b *netpollByteBuffer) ReadableLen() (n int) { + if b.status&remote.BitReadable == 0 { + return -1 + } + return b.reader.Len() +} + +// Read implement io.Reader +func (b *netpollByteBuffer) Read(p []byte) (n int, err error) { + if b.status&remote.BitReadable == 0 { + return -1, errors.New("unreadable buffer, cannot support Read") + } + if b.ioReader != nil { + return b.ioReader.Read(p) + } + return -1, errors.New("ioReader is nil") +} + +// ReadString is a more efficient way to read string than Next. +func (b *netpollByteBuffer) ReadString(n int) (s string, err error) { + if b.status&remote.BitReadable == 0 { + return "", errors.New("unreadable buffer, cannot support ReadString") + } + if s, err = b.reader.ReadString(n); err == nil { + b.readSize += n + } + return +} + +// ReadBinary like ReadString. +// Returns a copy of original buffer. +func (b *netpollByteBuffer) ReadBinary(n int) (p []byte, err error) { + if b.status&remote.BitReadable == 0 { + return p, errors.New("unreadable buffer, cannot support ReadBinary") + } + if p, err = b.reader.ReadBinary(n); err == nil { + b.readSize += n + } + return +} + +// Malloc n bytes sequentially in the writer buffer. +func (b *netpollByteBuffer) Malloc(n int) (buf []byte, err error) { + if b.status&remote.BitWritable == 0 { + return nil, errors.New("unwritable buffer, cannot support Malloc") + } + return b.writer.Malloc(n) +} + +// MallocAck n bytes in the writer buffer. +func (b *netpollByteBuffer) MallocAck(n int) (err error) { + if b.status&remote.BitWritable == 0 { + return errors.New("unwritable buffer, cannot support MallocAck") + } + return b.writer.MallocAck(n) +} + +// MallocLen returns the total length of the buffer malloced. +func (b *netpollByteBuffer) MallocLen() (length int) { + if b.status&remote.BitWritable == 0 { + return -1 + } + return b.writer.MallocLen() +} + +// Write implement io.Writer +func (b *netpollByteBuffer) Write(p []byte) (n int, err error) { + if b.status&remote.BitWritable == 0 { + return -1, errors.New("unwritable buffer, cannot support Write") + } + if b.ioWriter != nil { + return b.ioWriter.Write(p) + } + return -1, errors.New("ioWriter is nil") +} + +// WriteString is a more efficient way to write string, using the unsafe method to convert the string to []byte. +func (b *netpollByteBuffer) WriteString(s string) (n int, err error) { + if b.status&remote.BitWritable == 0 { + return -1, errors.New("unwritable buffer, cannot support WriteString") + } + return b.writer.WriteString(s) +} + +// WriteBinary writes the []byte directly. Callers must guarantee that the []byte doesn't change. +func (b *netpollByteBuffer) WriteBinary(p []byte) (n int, err error) { + if b.status&remote.BitWritable == 0 { + return -1, errors.New("unwritable buffer, cannot support WriteBinary") + } + return b.writer.WriteBinary(p) +} + +// WriteDirect is a way to write []byte without copying, and splits the original buffer. +func (b *netpollByteBuffer) WriteDirect(p []byte, remainCap int) error { + if b.status&remote.BitWritable == 0 { + return errors.New("unwritable buffer, cannot support WriteBinary") + } + return b.writer.WriteDirect(p, remainCap) +} + +// ReadLen returns the size already read. +func (b *netpollByteBuffer) ReadLen() (n int) { + return b.readSize +} + +// Flush writes any malloc data to the underlying io.Writer. +// The malloced buffer must be set correctly. +func (b *netpollByteBuffer) Flush() (err error) { + if b.status&remote.BitWritable == 0 { + return errors.New("unwritable buffer, cannot support Flush") + } + return b.writer.Flush() +} + +// NewBuffer returns a new writable remote.ByteBuffer. +func (b *netpollByteBuffer) NewBuffer() remote.ByteBuffer { + return NewWriterByteBuffer(netpoll.NewLinkBuffer()) +} + +// AppendBuffer appends buf to the original buffer. +func (b *netpollByteBuffer) AppendBuffer(buf remote.ByteBuffer) (err error) { + subBuf := buf.(*netpollByteBuffer) + err = b.writer.Append(subBuf.writer) + buf.Release(nil) + return +} + +// Bytes are not supported in netpoll bytebuf. +func (b *netpollByteBuffer) Bytes() (buf []byte, err error) { + if b.reader != nil { + return b.reader.Peek(b.reader.Len()) + } + return nil, errors.New("method Bytes() not support in netpoll bytebuf") +} + +// Release will free the buffer already read. +// After release, buffer read by Next/Skip/Peek is invalid. +func (b *netpollByteBuffer) Release(e error) (err error) { + if b.reader != nil { + // 重复执行Release nil panic + err = b.reader.Release() + } + b.zero() + bytebufPool.Put(b) + return +} + +func (b *netpollByteBuffer) zero() { + b.writer = nil + b.reader = nil + b.ioReader = nil + b.ioWriter = nil + b.status = 0 + b.readSize = 0 +} diff --git a/pkg/remote/trans/netpoll/bytebuf_test.go b/pkg/remote/trans/netpoll/bytebuf/bytebuf_test.go similarity index 99% rename from pkg/remote/trans/netpoll/bytebuf_test.go rename to pkg/remote/trans/netpoll/bytebuf/bytebuf_test.go index 2642513076..595e7ed36d 100644 --- a/pkg/remote/trans/netpoll/bytebuf_test.go +++ b/pkg/remote/trans/netpoll/bytebuf/bytebuf_test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package netpoll +package bytebuf import ( "bufio" diff --git a/pkg/remote/trans/netpoll/client_handler.go b/pkg/remote/trans/netpoll/client_handler.go index fdbbd17aaa..843f16afd3 100644 --- a/pkg/remote/trans/netpoll/client_handler.go +++ b/pkg/remote/trans/netpoll/client_handler.go @@ -19,6 +19,12 @@ package netpoll import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/trans" + "github.com/cloudwego/kitex/pkg/remote/trans/ttheaderstreaming" +) + +var ( + _ remote.ClientTransHandlerFactory = (*cliTransHandlerFactory)(nil) + _ remote.ClientStreamTransHandlerFactory = (*cliTransHandlerFactory)(nil) ) type cliTransHandlerFactory struct{} @@ -36,3 +42,7 @@ func (f *cliTransHandlerFactory) NewTransHandler(opt *remote.ClientOption) (remo func newCliTransHandler(opt *remote.ClientOption) (remote.ClientTransHandler, error) { return trans.NewDefaultCliTransHandler(opt, NewNetpollConnExtension()) } + +func (f *cliTransHandlerFactory) NewStreamTransHandler(opt *remote.ClientOption) (remote.ClientTransHandler, error) { + return ttheaderstreaming.NewCliTransHandler(opt, NewNetpollConnExtension()) +} diff --git a/pkg/remote/trans/netpoll/server_handler.go b/pkg/remote/trans/netpoll/server_handler.go index f2e703ce77..ae8dff6d3c 100644 --- a/pkg/remote/trans/netpoll/server_handler.go +++ b/pkg/remote/trans/netpoll/server_handler.go @@ -19,6 +19,7 @@ package netpoll import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/trans" + "github.com/cloudwego/kitex/pkg/remote/trans/ttheaderstreaming" ) type svrTransHandlerFactory struct{} @@ -36,3 +37,13 @@ func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remo func newSvrTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) { return trans.NewDefaultSvrTransHandler(opt, NewNetpollConnExtension()) } + +type ttheaderStreamingSvrTransHandlerFactory struct{} + +func (s *ttheaderStreamingSvrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) { + return ttheaderstreaming.NewSvrTransHandler(opt, NewNetpollConnExtension()) +} + +func NewTTHeaderStreamingSvrTransHandlerFactory() remote.ServerTransHandlerFactory { + return &ttheaderStreamingSvrTransHandlerFactory{} +} diff --git a/pkg/remote/trans/nphttp2/client_handler.go b/pkg/remote/trans/nphttp2/client_handler.go index 9b98df5fbb..f5f98b4af7 100644 --- a/pkg/remote/trans/nphttp2/client_handler.go +++ b/pkg/remote/trans/nphttp2/client_handler.go @@ -25,8 +25,12 @@ import ( "github.com/cloudwego/kitex/pkg/remote/codec/grpc" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" "github.com/cloudwego/kitex/pkg/rpcinfo" + kitex "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streaming" ) +var _ remote.ClientStreamTransHandlerFactory = (*cliTransHandlerFactory)(nil) + type cliTransHandlerFactory struct{} // NewCliTransHandlerFactory ... @@ -38,6 +42,10 @@ func (f *cliTransHandlerFactory) NewTransHandler(opt *remote.ClientOption) (remo return newCliTransHandler(opt) } +func (f *cliTransHandlerFactory) NewStreamTransHandler(opt *remote.ClientOption) (remote.ClientTransHandler, error) { + return newCliTransHandler(opt) +} + func newCliTransHandler(opt *remote.ClientOption) (*cliTransHandler, error) { return &cliTransHandler{ opt: opt, @@ -45,13 +53,22 @@ func newCliTransHandler(opt *remote.ClientOption) (*cliTransHandler, error) { }, nil } -var _ remote.ClientTransHandler = &cliTransHandler{} +var ( + _ remote.ClientTransHandler = &cliTransHandler{} + _ remote.ClientStreamAllocator = &cliTransHandler{} +) type cliTransHandler struct { opt *remote.ClientOption codec remote.Codec } +func (h *cliTransHandler) NewStream( + ctx context.Context, svcInfo *kitex.ServiceInfo, conn net.Conn, handler remote.TransReadWriter, +) (streaming.Stream, error) { + return NewStream(ctx, svcInfo, conn, handler), nil +} + func (h *cliTransHandler) Write(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) { buf := newBuffer(conn.(*clientConn)) defer buf.Release(err) diff --git a/pkg/remote/trans/nphttp2/server_handler.go b/pkg/remote/trans/nphttp2/server_handler.go index dd99c38915..b3e702c549 100644 --- a/pkg/remote/trans/nphttp2/server_handler.go +++ b/pkg/remote/trans/nphttp2/server_handler.go @@ -200,7 +200,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { } rawStream := NewStream(rCtx, svcInfo, newServerConn(tr, s), t) - st := newStreamWithMiddleware(rawStream, t.opt.RecvEndpoint, t.opt.SendEndpoint) + st := endpoint.NewStreamWithMiddleware(rawStream, t.opt.RecvEndpoint, t.opt.SendEndpoint) // bind stream into ctx, in order to let user set header and trailer by provided api in meta_api.go rCtx = streaming.NewCtxWithStream(rCtx, st) diff --git a/pkg/remote/trans/ttheaderstreaming/client_handler.go b/pkg/remote/trans/ttheaderstreaming/client_handler.go new file mode 100644 index 0000000000..f7c0eb5814 --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/client_handler.go @@ -0,0 +1,144 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import ( + "context" + "fmt" + "net" + + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/remote/codec/ttheader" + "github.com/cloudwego/kitex/pkg/remote/trans" + "github.com/cloudwego/kitex/pkg/remote/transmeta" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streaming" + "github.com/cloudwego/kitex/transport" +) + +var _ remote.ClientStreamAllocator = (*ttheaderStreamingClientTransHandler)(nil) + +func NewCliTransHandler(opt *remote.ClientOption, ext trans.Extension) (remote.ClientTransHandler, error) { + return &ttheaderStreamingClientTransHandler{ + opt: opt, + codec: ttheader.NewStreamCodec(), + ext: ext, + frameMetaHandler: getClientFrameMetaHandler(opt.TTHeaderFrameMetaHandler), + }, nil +} + +func getClientFrameMetaHandler(handler remote.FrameMetaHandler) remote.FrameMetaHandler { + if handler != nil { + return handler + } + return NewClientTTHeaderFrameHandler() +} + +type ttheaderStreamingClientTransHandler struct { + opt *remote.ClientOption + codec remote.Codec + ext trans.Extension + frameMetaHandler remote.FrameMetaHandler +} + +func (t *ttheaderStreamingClientTransHandler) newHeaderMessage( + ctx context.Context, svcInfo *serviceinfo.ServiceInfo, +) (nCtx context.Context, message remote.Message, err error) { + ri := rpcinfo.GetRPCInfo(ctx) + message = remote.NewMessage(nil, svcInfo, ri, remote.Stream, remote.Client) + message.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, svcInfo.PayloadCodec)) + message.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + + if err = t.frameMetaHandler.WriteHeader(ctx, message); err != nil { + return ctx, nil, fmt.Errorf("ClientFrameProcessor.WriteHeader failed: %w", err) + } + + for _, h := range t.opt.StreamingMetaHandlers { + if ctx, err = h.OnConnectStream(ctx); err != nil { + return ctx, nil, fmt.Errorf("%T.OnConnectStream failed: %w", h, err) + } + } + return ctx, message, nil +} + +func (t *ttheaderStreamingClientTransHandler) NewStream( + ctx context.Context, svcInfo *serviceinfo.ServiceInfo, conn net.Conn, handler remote.TransReadWriter, +) (streaming.Stream, error) { + ctx, headerMessage, err := t.newHeaderMessage(ctx, svcInfo) + if err != nil { + return nil, err + } + + rawStream := newTTHeaderStream( + ctx, conn, headerMessage.RPCInfo(), t.codec, remote.Client, t.ext, t.frameMetaHandler) + if t.opt.TTHeaderStreamingGRPCCompatible { + rawStream.loadOutgoingMetadataForGRPC() + } + defer func() { + if err != nil { + _ = rawStream.closeSend(err) // make sure the sendLoop is released + } + }() + + if err = rawStream.sendHeaderFrame(headerMessage); err != nil { + return nil, err + } + + if t.opt.TTHeaderStreamingWaitMetaFrame { + // Not enabled by default for performance reason, and should only be enabled when a MetaFrame will + // definitely be returned, otherwise it will cause an error or even a hang + // For example, when service discovery is taken over by a proxy, the proxy should return a MetaFrame, + // so that Kitex client can get the real address and set it into RPCInfo; otherwise the tracer may not + // be able to get the right address (e.g. for Client Streaming send events) + if err = rawStream.clientReadMetaFrame(); err != nil { + return nil, err + } + } + return rawStream, nil +} + +// Write is not used and should never be called +func (t *ttheaderStreamingClientTransHandler) Write(ctx context.Context, conn net.Conn, send remote.Message) (nctx context.Context, err error) { + panic("not used") +} + +// Read is not used and should never be called +func (t *ttheaderStreamingClientTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) { + panic("not used") +} + +// OnInactive is not used and should never be called +func (t *ttheaderStreamingClientTransHandler) OnInactive(ctx context.Context, conn net.Conn) { + panic("not used") +} + +// OnError is not used and should never be called +func (t *ttheaderStreamingClientTransHandler) OnError(ctx context.Context, err error, conn net.Conn) { + panic("not used") +} + +// OnMessage is not used and should never be called +func (t *ttheaderStreamingClientTransHandler) OnMessage(ctx context.Context, args, result remote.Message) (context.Context, error) { + panic("not used") +} + +// SetPipeline is not used and should never be called +func (t *ttheaderStreamingClientTransHandler) SetPipeline(pipeline *remote.TransPipeline) { + panic("not used") +} diff --git a/pkg/remote/trans/ttheaderstreaming/client_handler_test.go b/pkg/remote/trans/ttheaderstreaming/client_handler_test.go new file mode 100644 index 0000000000..1de03ffc33 --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/client_handler_test.go @@ -0,0 +1,278 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import ( + "context" + "errors" + "testing" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/remote/codec/ttheader" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" + "github.com/cloudwego/kitex/pkg/remote/transmeta" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/transport" +) + +func Test_getClientFrameMetaHandler(t *testing.T) { + t.Run("nil", func(t *testing.T) { + h := getClientFrameMetaHandler(nil) + _, ok := h.(*clientTTHeaderFrameHandler) + test.Assert(t, ok, h) + }) + t.Run("non-nil", func(t *testing.T) { + h := getClientFrameMetaHandler(&mockFrameMetaHandler{}) + _, ok := h.(*mockFrameMetaHandler) + test.Assert(t, ok, h) + }) +} + +func Test_ttheaderStreamingClientTransHandler_newHeaderMessage(t *testing.T) { + svcInfo := &serviceinfo.ServiceInfo{ + ServiceName: "xxx", + PayloadCodec: serviceinfo.Protobuf, + } + + t.Run("newHeaderMessage", func(t *testing.T) { + metaHandler := &mockFrameMetaHandler{ + writeHeader: func(ctx context.Context, message remote.Message) error { + message.TransInfo().TransStrInfo()["xxx"] = "1" + return nil + }, + onConnectStream: func(ctx context.Context) (context.Context, error) { + return context.WithValue(ctx, "yyy", "2"), nil + }, + } + opt := &remote.ClientOption{ + Option: remote.Option{ + TTHeaderFrameMetaHandler: metaHandler, + StreamingMetaHandlers: []remote.StreamingMetaHandler{metaHandler}, + }, + } + h, err := NewCliTransHandler(opt, nil) + test.Assert(t, err == nil, err) + + clientHandler := h.(*ttheaderStreamingClientTransHandler) + + ivk := rpcinfo.NewInvocation("idl-service-name", "to-method") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + ctx, msg, err := clientHandler.newHeaderMessage(ctx, svcInfo) + test.Assert(t, err == nil, err) + test.Assert(t, msg.MessageType() == remote.Stream, msg.MessageType()) + test.Assert(t, msg.RPCRole() == remote.Client, msg.RPCRole()) + test.Assert(t, msg.ProtocolInfo().TransProto == transport.TTHeader, msg.ProtocolInfo()) + test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Protobuf, msg.ProtocolInfo()) + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, msg.TransInfo().TransIntInfo()) + + // WriteHeader check + test.Assert(t, msg.TransInfo().TransStrInfo()["xxx"] == "1", msg.TransInfo().TransStrInfo()) + + // StreamingMetaHandler check + test.Assert(t, ctx.Value("yyy") == "2", ctx.Value("yyy")) + }) + + t.Run("write-header-err", func(t *testing.T) { + metaHandler := &mockFrameMetaHandler{ + writeHeader: func(ctx context.Context, message remote.Message) error { + return errors.New("xxx") + }, + } + opt := &remote.ClientOption{ + Option: remote.Option{ + TTHeaderFrameMetaHandler: metaHandler, + }, + } + h, err := NewCliTransHandler(opt, nil) + test.Assert(t, err == nil, err) + + clientHandler := h.(*ttheaderStreamingClientTransHandler) + + ivk := rpcinfo.NewInvocation("idl-service-name", "to-method") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + _, _, err = clientHandler.newHeaderMessage(ctx, svcInfo) + test.Assert(t, err != nil, err) + }) + + t.Run("OnConnectStream-err", func(t *testing.T) { + metaHandler := &mockFrameMetaHandler{ + writeHeader: func(ctx context.Context, message remote.Message) error { + return nil + }, + onConnectStream: func(ctx context.Context) (context.Context, error) { + return nil, errors.New("xxx") + }, + } + opt := &remote.ClientOption{ + Option: remote.Option{ + TTHeaderFrameMetaHandler: metaHandler, + StreamingMetaHandlers: []remote.StreamingMetaHandler{metaHandler}, + }, + } + h, err := NewCliTransHandler(opt, nil) + test.Assert(t, err == nil, err) + + clientHandler := h.(*ttheaderStreamingClientTransHandler) + + ivk := rpcinfo.NewInvocation("idl-service-name", "to-method") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + _, _, err = clientHandler.newHeaderMessage(ctx, svcInfo) + test.Assert(t, err != nil, err) + }) +} + +func Test_ttheaderStreamingClientTransHandler_NewStream(t *testing.T) { + svcInfo := &serviceinfo.ServiceInfo{ + ServiceName: "xxx", + PayloadCodec: serviceinfo.Protobuf, + } + + t.Run("new-header-err", func(t *testing.T) { + metaHandler := &mockFrameMetaHandler{ + writeHeader: func(ctx context.Context, message remote.Message) error { + return errors.New("xxx") + }, + } + opt := &remote.ClientOption{ + Option: remote.Option{ + TTHeaderFrameMetaHandler: metaHandler, + }, + } + h, err := NewCliTransHandler(opt, nil) + test.Assert(t, err == nil, err) + + clientHandler := h.(*ttheaderStreamingClientTransHandler) + + ivk := rpcinfo.NewInvocation("idl-service-name", "to-method") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + _, err = clientHandler.NewStream(ctx, svcInfo, nil, nil) + test.Assert(t, err != nil, err) + }) + + t.Run("normal-with-grpc-meta-and-wait-meta-frame", func(t *testing.T) { + opt := &remote.ClientOption{ + TTHeaderStreamingGRPCCompatible: true, + TTHeaderStreamingWaitMetaFrame: true, + Option: remote.Option{ + TTHeaderFrameMetaHandler: &mockFrameMetaHandler{ + readMeta: func(ctx context.Context, msg remote.Message) error { + to := rpcinfo.AsMutableEndpointInfo(msg.RPCInfo().To()) + _ = to.SetTag("meta_key", msg.TransInfo().TransStrInfo()["meta_key"]) + return nil + }, + writeHeader: func(ctx context.Context, message remote.Message) error { + return nil + }, + }, + }, + } + ext := &mockExtension{} + h, err := NewCliTransHandler(opt, ext) + test.Assert(t, err == nil, err) + clientHandler := h.(*ttheaderStreamingClientTransHandler) + + writeChan := make(chan []byte, 3) + conn := &mockConn{ + readBuf: func() []byte { + ivk := rpcinfo.NewInvocation("idl-service-name", "to-method") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + msg := remote.NewMessage(nil, svcInfo, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeMeta + msg.TransInfo().TransStrInfo()["meta_key"] = "meta_value" + out := remote.NewWriterBuffer(1024) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + if err = ttheader.NewStreamCodec().Encode(ctx, msg, out); err != nil { + panic(err) + } + buf, _ := out.Bytes() + return buf + }(), + write: func(b []byte) (int, error) { + writeBuf := make([]byte, len(b)) + n := copy(writeBuf, b) + writeChan <- writeBuf + return n, nil + }, + } + + ivk := rpcinfo.NewInvocation("idl-service-name", "to-method") + to := rpcinfo.NewMutableEndpointInfo("to-service", "to-instance", nil, nil) + ri := rpcinfo.NewRPCInfo(nil, to.ImmutableView(), ivk, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ctx = metadata.AppendToOutgoingContext(ctx, "k1", "v1", "k1", "v2") + + st, err := clientHandler.NewStream(ctx, svcInfo, conn, nil) + test.Assert(t, err == nil, err) + _ = st.Close() + + buf := <-writeChan + test.Assert(t, len(buf) > 0, buf) + msg := remote.NewMessage(nil, svcInfo, ri, remote.Stream, remote.Client) + err = codec.TTHeaderCodec().DecodeHeader(ctx, msg, remote.NewReaderBuffer(buf)) + test.Assert(t, err == nil, err) + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, msg.TransInfo().TransIntInfo()) + meta := msg.TransInfo().TransStrInfo()[codec.StrKeyMetaData] + test.Assert(t, meta == `{"k1":["v1","v2"]}`, meta) + + v, exists := ri.To().Tag("meta_key") + test.Assert(t, exists) + test.Assert(t, v == "meta_value") + }) + + t.Run("read-meta-frame-err", func(t *testing.T) { + opt := &remote.ClientOption{ + TTHeaderStreamingWaitMetaFrame: true, + } + ext := &mockExtension{} + h, err := NewCliTransHandler(opt, ext) + test.Assert(t, err == nil, err) + clientHandler := h.(*ttheaderStreamingClientTransHandler) + + writeChan := make(chan []byte, 3) + expectedErr := errors.New("read-meta-err") + conn := &mockConn{ + read: func(b []byte) (int, error) { + return -1, expectedErr + }, + write: func(b []byte) (int, error) { + writeBuf := make([]byte, len(b)) + n := copy(writeBuf, b) + writeChan <- writeBuf + return n, nil + }, + } + + ivk := rpcinfo.NewInvocation("idl-service-name", "to-method") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + st, err := clientHandler.NewStream(ctx, svcInfo, conn, nil) + test.Assert(t, expectedErr.Error() == err.Error(), err) + test.Assert(t, st == nil, st) + }) +} diff --git a/pkg/remote/trans/ttheaderstreaming/frame_meta_handler.go b/pkg/remote/trans/ttheaderstreaming/frame_meta_handler.go new file mode 100644 index 0000000000..3a663d0b00 --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/frame_meta_handler.go @@ -0,0 +1,209 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import ( + "context" + "errors" + "fmt" + "strconv" + "time" + + "github.com/bytedance/gopkg/cloud/metainfo" + + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/transmeta" + "github.com/cloudwego/kitex/pkg/rpcinfo" + header_transmeta "github.com/cloudwego/kitex/pkg/transmeta" +) + +var ( + _ remote.FrameMetaHandler = (*clientTTHeaderFrameHandler)(nil) + _ remote.FrameMetaHandler = (*serverTTHeaderFrameHandler)(nil) +) + +type clientTTHeaderFrameHandler struct{} + +func NewClientTTHeaderFrameHandler() remote.FrameMetaHandler { + return &clientTTHeaderFrameHandler{} +} + +// ReadMeta does nothing by default (only for client) +func (c *clientTTHeaderFrameHandler) ReadMeta(ctx context.Context, message remote.Message) error { + return nil +} + +// ClientReadHeader does nothing by default at client side +func (c *clientTTHeaderFrameHandler) ClientReadHeader(ctx context.Context, message remote.Message) error { + return nil +} + +// ServerReadHeader should not be called at client side +func (c *clientTTHeaderFrameHandler) ServerReadHeader(ctx context.Context, message remote.Message) (context.Context, error) { + panic("ServerReadHeader should not be called at client side") +} + +// WriteHeader writes necessary keys for TTHeader Streaming +func (c *clientTTHeaderFrameHandler) WriteHeader(ctx context.Context, msg remote.Message) error { + ri := msg.RPCInfo() + + intInfo := map[uint16]string{ + transmeta.ToMethod: ri.Invocation().MethodName(), + } + if deadline, ok := ctx.Deadline(); ok { + if now := time.Now(); deadline.After(now) { + intInfo[transmeta.RPCTimeout] = strconv.Itoa(int(deadline.Sub(now).Milliseconds())) + } + } + msg.TransInfo().PutTransIntInfo(intInfo) + + strInfo := msg.TransInfo().TransStrInfo() + metainfo.SaveMetaInfoToMap(ctx, strInfo) + strInfo[transmeta.HeaderIDLServiceName] = ri.Invocation().ServiceName() + return nil +} + +func (c *clientTTHeaderFrameHandler) ReadTrailer(ctx context.Context, msg remote.Message) error { + intInfo := msg.TransInfo().TransIntInfo() + strInfo := msg.TransInfo().TransStrInfo() + + // TransError + if err := parseTransErrorIfExists(intInfo); err != nil { + return err + } + + // BizStatusError + if err := header_transmeta.ParseBizStatusErrorToRPCInfo(strInfo, msg.RPCInfo()); err != nil { + return err + } + + // TODO: metainfo backward keys + return nil +} + +func parseTransErrorIfExists(intInfo map[uint16]string) error { + if transCode, exists := intInfo[transmeta.TransCode]; exists { + if code, err := strconv.Atoi(transCode); err != nil { + msg := fmt.Sprintf("parse trans code failed, code: %s, err: %s", transCode, err) + return remote.NewTransErrorWithMsg(remote.InternalError, msg) + } else if code != 0 { + transMsg := intInfo[transmeta.TransMessage] + return remote.NewTransErrorWithMsg(int32(code), transMsg) + } + } + return nil +} + +// WriteTrailer for client +func (c *clientTTHeaderFrameHandler) WriteTrailer(ctx context.Context, msg remote.Message) error { + if err, ok := msg.Data().(error); ok && err != nil { + injectTransError(msg, err) + } + return nil +} + +type serverTTHeaderFrameHandler struct{} + +func NewServerTTHeaderFrameHandler() remote.FrameMetaHandler { + return &serverTTHeaderFrameHandler{} +} + +// ReadMeta should not be called at server side +func (s *serverTTHeaderFrameHandler) ReadMeta(ctx context.Context, message remote.Message) error { + panic("ReadMeta should not be called at server side") +} + +// ClientReadHeader should not be called at server side +func (s *serverTTHeaderFrameHandler) ClientReadHeader(ctx context.Context, message remote.Message) error { + panic("ClientReadHeader should not be called at server side") +} + +// ServerReadHeader reads necessary keys for TTHeader Streaming +func (s *serverTTHeaderFrameHandler) ServerReadHeader(ctx context.Context, msg remote.Message) (context.Context, error) { + ri := msg.RPCInfo() + transInfo := msg.TransInfo() + intInfo := transInfo.TransIntInfo() + + if cfg := rpcinfo.AsMutableRPCConfig(ri.Config()); cfg != nil { + cfg.SetPayloadCodec(msg.ProtocolInfo().CodecType) + timeout := intInfo[transmeta.RPCTimeout] + if timeoutMS, err := strconv.Atoi(timeout); err == nil { + _ = cfg.SetRPCTimeout(time.Duration(timeoutMS) * time.Millisecond) + } + } + + ink := ri.Invocation().(rpcinfo.InvocationSetter) + if method, exists := intInfo[transmeta.ToMethod]; !exists { // method must exist in Header Frame + ErrHeaderFrameInvalidToMethod := errors.New("missing method in ttheader streaming header frame") + return ctx, ErrHeaderFrameInvalidToMethod + } else { + ink.SetMethodName(method) + } + + if idlServiceName, exists := transInfo.TransStrInfo()[transmeta.HeaderIDLServiceName]; exists { + // IDL Service Name is only necessary for multi-service servers. + ink.SetServiceName(idlServiceName) + } + + // cloudwego metainfo + ctx = metainfo.SetMetaInfoFromMap(ctx, msg.TransInfo().TransStrInfo()) + + // grpc style metadata + return injectGRPCMetadata(ctx, msg.TransInfo().TransStrInfo()) +} + +// WriteHeader does nothing by default at server side +func (s *serverTTHeaderFrameHandler) WriteHeader(ctx context.Context, msg remote.Message) error { + return nil +} + +// ReadTrailer for server +func (s *serverTTHeaderFrameHandler) ReadTrailer(ctx context.Context, msg remote.Message) error { + intInfo := msg.TransInfo().TransIntInfo() + + // TransError from Client + if err := parseTransErrorIfExists(intInfo); err != nil { + return err + } + + return nil +} + +// WriteTrailer ... +func (s *serverTTHeaderFrameHandler) WriteTrailer(ctx context.Context, msg remote.Message) error { + // TODO: metainfo backward keys + if err, ok := msg.Data().(error); ok && err != nil { + injectTransError(msg, err) + return nil // no need to deal with BizStatusError + } + header_transmeta.InjectBizStatusError(msg.TransInfo().TransStrInfo(), msg.RPCInfo().Invocation().BizStatusErr()) + return nil +} + +// injectTransError set an error to message header +// if err is not a remote.TransError, the exceptionType will be set to remote.InternalError +func injectTransError(message remote.Message, err error) { + exceptionType := remote.InternalError + var transErr *remote.TransError + if errors.As(err, &transErr) { + exceptionType = int(transErr.TypeID()) + } + message.TransInfo().PutTransIntInfo(map[uint16]string{ + transmeta.TransCode: strconv.Itoa(int(exceptionType)), + transmeta.TransMessage: err.Error(), + }) +} diff --git a/pkg/remote/trans/ttheaderstreaming/frame_meta_handler_test.go b/pkg/remote/trans/ttheaderstreaming/frame_meta_handler_test.go new file mode 100644 index 0000000000..6b7733f257 --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/frame_meta_handler_test.go @@ -0,0 +1,370 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import ( + "context" + "errors" + "strconv" + "strings" + "testing" + "time" + + "github.com/bytedance/gopkg/cloud/metainfo" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" + "github.com/cloudwego/kitex/pkg/remote/transmeta" + "github.com/cloudwego/kitex/pkg/rpcinfo" + header_transmeta "github.com/cloudwego/kitex/pkg/transmeta" +) + +func Test_clientTTHeaderFrameHandler_ServerReadHeader(t *testing.T) { + defer func() { + if p := recover(); p != nil { + t.Logf("expected panic") + } else { + t.Errorf("missing expected panic") + } + }() + h := NewClientTTHeaderFrameHandler() + _, _ = h.ServerReadHeader(context.Background(), nil) +} + +func Test_clientTTHeaderFrameHandler_WriteHeader(t *testing.T) { + ivk := rpcinfo.NewInvocation("idl-service-name", "method-name") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, nil) + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + + h := NewClientTTHeaderFrameHandler() + _ = h.WriteHeader(context.Background(), msg) + + intInfo := msg.TransInfo().TransIntInfo() + strInfo := msg.TransInfo().TransStrInfo() + test.Assert(t, intInfo[transmeta.ToMethod] == ivk.MethodName(), intInfo) + test.Assert(t, strInfo[transmeta.HeaderIDLServiceName] == ivk.ServiceName(), strInfo) +} + +func Test_clientTTHeaderFrameHandler_ReadTrailer(t *testing.T) { + t.Run("trans-error-parse-failure", func(t *testing.T) { + msg := remote.NewMessage(nil, nil, nil, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.TransCode] = "X" + + h := NewClientTTHeaderFrameHandler() + err := h.ReadTrailer(context.Background(), msg) + test.Assert(t, err != nil, err) + transErr, ok := err.(*remote.TransError) + test.Assert(t, ok, err) + test.Assert(t, transErr.TypeID() == remote.InternalError, transErr) + }) + + t.Run("trans-error", func(t *testing.T) { + msg := remote.NewMessage(nil, nil, nil, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.TransCode] = "1204" + msg.TransInfo().TransIntInfo()[transmeta.TransMessage] = "mesh timeout" + + h := NewClientTTHeaderFrameHandler() + err := h.ReadTrailer(context.Background(), msg) + test.Assert(t, err != nil, err) + transErr, ok := err.(*remote.TransError) + test.Assert(t, ok, err) + test.Assert(t, transErr.TypeID() == 1204, transErr) + test.Assert(t, transErr.Error() == "mesh timeout", transErr) + }) + + t.Run("biz-status-error-parse-failure", func(t *testing.T) { + ivk := rpcinfo.NewInvocation("", "") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, nil) + bizErr := kerrors.NewBizStatusError(100, "biz error") + ri.Invocation().(rpcinfo.InvocationSetter).SetBizStatusErr(bizErr) + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + header_transmeta.InjectBizStatusError(msg.TransInfo().TransStrInfo(), ri.Invocation().BizStatusErr()) + msg.TransInfo().TransStrInfo()["biz-extra"] = "invalid biz extra" + + h := NewClientTTHeaderFrameHandler() + err := h.ReadTrailer(context.Background(), msg) + test.Assert(t, err != nil, err) + }) + + t.Run("biz-status-error", func(t *testing.T) { + ivk := rpcinfo.NewInvocation("", "") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, nil) + bizErr := kerrors.NewBizStatusError(100, "biz error") + ri.Invocation().(rpcinfo.InvocationSetter).SetBizStatusErr(bizErr) + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + header_transmeta.InjectBizStatusError(msg.TransInfo().TransStrInfo(), ri.Invocation().BizStatusErr()) + + h := NewClientTTHeaderFrameHandler() + err := h.ReadTrailer(context.Background(), msg) + test.Assert(t, err == nil, err) + + gotBizErr := ri.Invocation().BizStatusErr() + test.Assert(t, gotBizErr.BizStatusCode() == bizErr.BizStatusCode(), gotBizErr) + test.Assert(t, gotBizErr.BizMessage() == bizErr.BizMessage(), gotBizErr) + }) +} + +func Test_parseTransErrorIfExists(t *testing.T) { + t.Run("no-trans-code", func(t *testing.T) { + test.Assert(t, parseTransErrorIfExists(map[uint16]string{}) == nil) + }) + t.Run("invalid-trans-code", func(t *testing.T) { + err := parseTransErrorIfExists(map[uint16]string{transmeta.TransCode: "X"}) + test.Assert(t, err != nil, err) + transErr := err.(*remote.TransError) + test.Assert(t, transErr.TypeID() == remote.InternalError, transErr) + }) + t.Run("code=0", func(t *testing.T) { + test.Assert(t, parseTransErrorIfExists(map[uint16]string{transmeta.TransCode: "0"}) == nil) + }) + t.Run("non-zero code", func(t *testing.T) { + intInfo := map[uint16]string{ + transmeta.TransCode: "1204", + transmeta.TransMessage: "mesh timeout", + } + err := parseTransErrorIfExists(intInfo) + test.Assert(t, err != nil, err) + transErr := err.(*remote.TransError) + test.Assert(t, transErr.TypeID() == 1204, transErr) + test.Assert(t, transErr.Error() == "mesh timeout", transErr) + }) +} + +func Test_clientTTHeaderFrameHandler_WriteTrailer(t *testing.T) { + t.Run("nil-trans-error", func(t *testing.T) { + msg := remote.NewMessage(nil, nil, nil, remote.Stream, remote.Client) + + h := NewClientTTHeaderFrameHandler() + err := h.WriteTrailer(context.Background(), msg) + test.Assert(t, err == nil, err) + + intInfo := msg.TransInfo().TransIntInfo() + test.Assert(t, len(intInfo) == 0) + }) + + t.Run("trans-error", func(t *testing.T) { + transErr := remote.NewTransError(1204, errors.New("mesh timeout")) + msg := remote.NewMessage(transErr, nil, nil, remote.Stream, remote.Client) + + h := NewClientTTHeaderFrameHandler() + err := h.WriteTrailer(context.Background(), msg) + test.Assert(t, err == nil, err) + + intInfo := msg.TransInfo().TransIntInfo() + test.Assert(t, len(intInfo) == 2) + test.Assert(t, intInfo[transmeta.TransCode] == "1204") + test.Assert(t, intInfo[transmeta.TransMessage] == "mesh timeout") + }) +} + +func Test_clientTTHeaderFrameHandler_WriteTrailer1(t *testing.T) { + t.Run("not-trans-error", func(t *testing.T) { + err := errors.New("xxx") + msg := remote.NewMessage(err, nil, nil, remote.Stream, remote.Client) + injectTransError(msg, err) + intInfo := msg.TransInfo().TransIntInfo() + test.Assert(t, intInfo[transmeta.TransCode] == strconv.Itoa(remote.InternalError), intInfo) + test.Assert(t, intInfo[transmeta.TransMessage] == err.Error(), intInfo) + }) + t.Run("trans-error", func(t *testing.T) { + transErr := remote.NewTransError(1204, errors.New("mesh timeout")) + msg := remote.NewMessage(transErr, nil, nil, remote.Stream, remote.Client) + injectTransError(msg, transErr) + intInfo := msg.TransInfo().TransIntInfo() + test.Assert(t, intInfo[transmeta.TransCode] == strconv.Itoa(1204), intInfo) + test.Assert(t, intInfo[transmeta.TransMessage] == "mesh timeout", intInfo) + }) +} + +func Test_serverTTHeaderFrameHandler_ReadMeta(t *testing.T) { + defer func() { + err := recover() + test.Assert(t, err != nil, err) + }() + msg := remote.NewMessage(nil, nil, nil, remote.Stream, remote.Server) + h := NewServerTTHeaderFrameHandler() + _ = h.ReadMeta(context.Background(), msg) + t.Errorf("ReadMeta should not be called at server side") +} + +func Test_serverTTHeaderFrameHandler_ClientReadHeader(t *testing.T) { + defer func() { + err := recover() + test.Assert(t, err != nil, err) + }() + msg := remote.NewMessage(nil, nil, nil, remote.Stream, remote.Server) + h := NewServerTTHeaderFrameHandler() + _ = h.ClientReadHeader(context.Background(), msg) + t.Errorf("ClientReadHeader should not be called at client side") +} + +func Test_serverTTHeaderFrameHandler_ServerReadHeader(t *testing.T) { + t.Run("missing-to-method", func(t *testing.T) { + ivk := rpcinfo.NewInvocation("", "") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, rpcinfo.NewRPCConfig(), nil) + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + h := NewServerTTHeaderFrameHandler() + _, err := h.ServerReadHeader(context.Background(), msg) + test.Assert(t, err != nil, err) + test.Assert(t, strings.HasPrefix(err.Error(), "missing method"), err) + }) + + t.Run("normal:to_method", func(t *testing.T) { + ivk := rpcinfo.NewInvocation("", "") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, rpcinfo.NewRPCConfig(), nil) + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + intInfo := msg.TransInfo().TransIntInfo() + intInfo[transmeta.ToMethod] = "to-method" + h := NewServerTTHeaderFrameHandler() + _, err := h.ServerReadHeader(context.Background(), msg) + test.Assert(t, err == nil, err) + test.Assert(t, ivk.MethodName() == "to-method", ivk.MethodName()) + test.Assert(t, ri.Config().RPCTimeout() == 0, ri.Config().RPCTimeout()) + }) + + t.Run("normal:+rpcTimeout", func(t *testing.T) { + ivk := rpcinfo.NewInvocation("", "") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, rpcinfo.NewRPCConfig(), nil) + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + intInfo := msg.TransInfo().TransIntInfo() + intInfo[transmeta.ToMethod] = "to-method" + intInfo[transmeta.RPCTimeout] = "1000" // ms + h := NewServerTTHeaderFrameHandler() + _, err := h.ServerReadHeader(context.Background(), msg) + test.Assert(t, err == nil, err) + test.Assert(t, ri.Config().RPCTimeout() == 1000*time.Millisecond, ri.Config().RPCTimeout()) + }) + + t.Run("normal:+isn", func(t *testing.T) { + ivk := rpcinfo.NewInvocation("", "") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, rpcinfo.NewRPCConfig(), nil) + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + intInfo := msg.TransInfo().TransIntInfo() + intInfo[transmeta.ToMethod] = "to-method" + msg.TransInfo().TransStrInfo()[transmeta.HeaderIDLServiceName] = "isn" + h := NewServerTTHeaderFrameHandler() + _, err := h.ServerReadHeader(context.Background(), msg) + test.Assert(t, err == nil, err) + test.Assert(t, ivk.ServiceName() == "isn", ivk.ServiceName()) + }) + + t.Run("normal:+metainfo", func(t *testing.T) { + ivk := rpcinfo.NewInvocation("", "") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, rpcinfo.NewRPCConfig(), nil) + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.ToMethod] = "to-method" + strInfo := msg.TransInfo().TransStrInfo() + strInfo[metainfo.PrefixTransient+"key"] = "value" + strInfo[metainfo.PrefixPersistent+"key"] = "value" + + h := NewServerTTHeaderFrameHandler() + ctx, err := h.ServerReadHeader(context.Background(), msg) + test.Assert(t, err == nil, err) + + value, _ := metainfo.GetValue(ctx, "key") + test.Assert(t, value == "value", value) + + pValue, _ := metainfo.GetPersistentValue(ctx, "key") + test.Assert(t, pValue == "value", pValue) + }) + + t.Run("normal:+grpc_metadata", func(t *testing.T) { + ivk := rpcinfo.NewInvocation("", "") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, rpcinfo.NewRPCConfig(), nil) + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.ToMethod] = "to-method" + msg.TransInfo().TransStrInfo()[codec.StrKeyMetaData] = `{"key":["v1", "v2"]}` + + h := NewServerTTHeaderFrameHandler() + ctx, err := h.ServerReadHeader(context.Background(), msg) + test.Assert(t, err == nil, err) + + md, ok := metadata.FromIncomingContext(ctx) + test.Assert(t, ok, ok) + test.Assert(t, md["key"][0] == "v1", md["key"][0]) + test.Assert(t, md["key"][1] == "v2", md["key"][1]) + }) +} + +func Test_serverTTHeaderFrameHandler_ReadTrailer(t *testing.T) { + t.Run("trans-error-parse-failure", func(t *testing.T) { + msg := remote.NewMessage(nil, nil, nil, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.TransCode] = "X" + + h := NewServerTTHeaderFrameHandler() + err := h.ReadTrailer(context.Background(), msg) + test.Assert(t, err != nil, err) + transErr, ok := err.(*remote.TransError) + test.Assert(t, ok, err) + test.Assert(t, transErr.TypeID() == remote.InternalError, transErr) + }) + + t.Run("trans-error", func(t *testing.T) { + msg := remote.NewMessage(nil, nil, nil, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.TransCode] = "999" + msg.TransInfo().TransIntInfo()[transmeta.TransMessage] = "client err" + + h := NewServerTTHeaderFrameHandler() + err := h.ReadTrailer(context.Background(), msg) + test.Assert(t, err != nil, err) + transErr, ok := err.(*remote.TransError) + test.Assert(t, ok, err) + test.Assert(t, transErr.TypeID() == 999, transErr) + test.Assert(t, transErr.Error() == "client err", transErr) + }) +} + +func Test_serverTTHeaderFrameHandler_WriteTrailer(t *testing.T) { + t.Run("trans-error", func(t *testing.T) { + transErr := remote.NewTransError(999, errors.New("client err")) + msg := remote.NewMessage(transErr, nil, nil, remote.Stream, remote.Server) + + h := NewClientTTHeaderFrameHandler() + err := h.WriteTrailer(context.Background(), msg) + test.Assert(t, err == nil, err) + + intInfo := msg.TransInfo().TransIntInfo() + test.Assert(t, len(intInfo) == 2) + test.Assert(t, intInfo[transmeta.TransCode] == strconv.Itoa(int(transErr.TypeID()))) + test.Assert(t, intInfo[transmeta.TransMessage] == transErr.Error()) + }) + t.Run("biz-err", func(t *testing.T) { + ivk := rpcinfo.NewInvocation("", "") + bizErr := kerrors.NewBizStatusErrorWithExtra(999, "biz err", map[string]string{"k": "v"}) + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, nil) + ri.Invocation().(rpcinfo.InvocationSetter).SetBizStatusErr(bizErr) + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + h := NewServerTTHeaderFrameHandler() + err := h.WriteTrailer(context.Background(), msg) + test.Assert(t, err == nil, err) + + strInfo := msg.TransInfo().TransStrInfo() + newIvk := rpcinfo.NewInvocation("", "") + newRI := rpcinfo.NewRPCInfo(nil, nil, newIvk, rpcinfo.NewRPCConfig(), nil) + err = header_transmeta.ParseBizStatusErrorToRPCInfo(strInfo, newRI) + test.Assert(t, err == nil, err) + + gotBizErr := newRI.Invocation().BizStatusErr() + test.Assert(t, gotBizErr != nil, gotBizErr) + test.Assert(t, gotBizErr.BizStatusCode() == bizErr.BizStatusCode(), bizErr) + test.Assert(t, gotBizErr.BizMessage() == bizErr.BizMessage(), bizErr) + test.Assert(t, gotBizErr.BizExtra()["k"] == "v", bizErr) + }) +} diff --git a/pkg/remote/trans/ttheaderstreaming/grpc_metadata.go b/pkg/remote/trans/ttheaderstreaming/grpc_metadata.go new file mode 100644 index 0000000000..b5d1263254 --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/grpc_metadata.go @@ -0,0 +1,110 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + "github.com/bytedance/sonic" + + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" +) + +// grpcMetadata contains the grpc style metadata, helpful for projects migrating from kitex-grpc +type grpcMetadata struct { + lock sync.Mutex + + headersToSend metadata.MD + trailerToSend metadata.MD + + headersReceived metadata.MD + trailerReceived metadata.MD +} + +func (g *grpcMetadata) setHeader(md metadata.MD) { + if md.Len() == 0 { + return + } + g.lock.Lock() + g.headersToSend = metadata.AppendMD(g.headersToSend, md) + g.lock.Unlock() +} + +func (g *grpcMetadata) setTrailer(md metadata.MD) { + if md.Len() == 0 { + return + } + g.lock.Lock() + g.trailerToSend = metadata.AppendMD(g.trailerToSend, md) + g.lock.Unlock() +} + +func (g *grpcMetadata) getMetadataAsJSON(md metadata.MD) ([]byte, error) { + if len(md) == 0 { + return nil, nil + } + return json.Marshal(md) +} + +func (g *grpcMetadata) injectHeader(message remote.Message) error { + return g.injectMetadata(g.headersToSend, message) +} + +func (g *grpcMetadata) injectTrailer(message remote.Message) error { + return g.injectMetadata(g.trailerToSend, message) +} + +func (g *grpcMetadata) injectMetadata(md metadata.MD, message remote.Message) error { + if metadataJSON, err := g.getMetadataAsJSON(md); err != nil { + return fmt.Errorf("failed to get metadata as json: err=%w", err) + } else if len(metadataJSON) > 0 { + message.TransInfo().TransStrInfo()[codec.StrKeyMetaData] = string(metadataJSON) + } + return nil +} + +func (g *grpcMetadata) parseHeader(message remote.Message) error { + return g.parseMetadata(message, &g.headersReceived) +} + +func (g *grpcMetadata) parseTrailer(message remote.Message) error { + return g.parseMetadata(message, &g.trailerReceived) +} + +func (g *grpcMetadata) parseMetadata(message remote.Message, target *metadata.MD) error { + if metadataJSON, exists := message.TransInfo().TransStrInfo()[codec.StrKeyMetaData]; exists { + md := metadata.MD{} + if err := sonic.Unmarshal([]byte(metadataJSON), &md); err != nil { + return fmt.Errorf("invalid metadata: json=%s, err = %w", metadataJSON, err) + } + g.lock.Lock() + *target = md + g.lock.Unlock() + } + return nil +} + +func (g *grpcMetadata) loadHeadersToSend(ctx context.Context) { + if md, ok := metadata.FromOutgoingContext(ctx); ok { + g.headersToSend = md + } +} diff --git a/pkg/remote/trans/ttheaderstreaming/grpc_metadata_test.go b/pkg/remote/trans/ttheaderstreaming/grpc_metadata_test.go new file mode 100644 index 0000000000..3ba79470b2 --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/grpc_metadata_test.go @@ -0,0 +1,168 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import ( + "context" + "reflect" + "testing" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" +) + +func Test_grpcMetadata_setHeader(t *testing.T) { + t.Run("set-with-empty-md", func(t *testing.T) { + md := metadata.MD{"k2": []string{"v2"}} + g := &grpcMetadata{headersToSend: md} + g.setHeader(metadata.MD{}) + test.Assert(t, reflect.DeepEqual(g.headersToSend, md), g.headersToSend) + }) + + t.Run("set-to-empty-map", func(t *testing.T) { + md := metadata.MD{"k1": []string{"v1"}} + g := &grpcMetadata{} + g.setHeader(md) + test.Assert(t, reflect.DeepEqual(g.headersToSend, md), g.headersToSend) + }) + + t.Run("set-to-non-empty-map", func(t *testing.T) { + md := metadata.MD{"k1": []string{"v1"}} + g := &grpcMetadata{headersToSend: metadata.MD{"k2": []string{"v2"}}} + g.setHeader(md) + expected := metadata.MD{"k1": []string{"v1"}, "k2": []string{"v2"}} + test.Assert(t, reflect.DeepEqual(g.headersToSend, expected), g.headersToSend) + }) +} + +func Test_grpcMetadata_setTrailer(t *testing.T) { + t.Run("set-with-empty-md", func(t *testing.T) { + md := metadata.MD{"k2": []string{"v2"}} + g := &grpcMetadata{trailerToSend: md} + g.setTrailer(metadata.MD{}) + test.Assert(t, reflect.DeepEqual(g.trailerToSend, md), g.trailerToSend) + }) + + t.Run("set-to-empty-map", func(t *testing.T) { + md := metadata.MD{"k1": []string{"v1"}} + g := &grpcMetadata{} + g.setTrailer(md) + test.Assert(t, reflect.DeepEqual(g.trailerToSend, md), g.trailerToSend) + }) + + t.Run("set-to-non-empty-map", func(t *testing.T) { + md := metadata.MD{"k1": []string{"v1"}} + g := &grpcMetadata{trailerToSend: metadata.MD{"k2": []string{"v2"}}} + g.setTrailer(md) + expected := metadata.MD{"k1": []string{"v1"}, "k2": []string{"v2"}} + test.Assert(t, reflect.DeepEqual(g.trailerToSend, expected), g.trailerToSend) + }) +} + +func Test_grpcMetadata_getMetadataAsJSON(t *testing.T) { + t.Run("trailer", func(t *testing.T) { + g := &grpcMetadata{trailerToSend: metadata.MD{"k1": []string{"v1"}}} + md, err := g.getMetadataAsJSON(g.trailerToSend) + test.Assert(t, err == nil, err) + test.Assert(t, string(md) == `{"k1":["v1"]}`) + }) + + t.Run("empty-header", func(t *testing.T) { + g := &grpcMetadata{} + md, err := g.getMetadataAsJSON(nil) + test.Assert(t, err == nil, err) + test.Assert(t, md == nil) + }) +} + +func Test_grpcMetadata_injectMetadata(t *testing.T) { + t.Run("inject-header", func(t *testing.T) { + msg := remote.NewMessage(nil, nil, nil, remote.Stream, remote.Server) + g := &grpcMetadata{headersToSend: metadata.MD{"k1": []string{"v1"}}} + err := g.injectMetadata(g.headersToSend, msg) + test.Assert(t, err == nil, err) + + mdJSON, exists := msg.TransInfo().TransStrInfo()[codec.StrKeyMetaData] + test.Assert(t, exists) + test.Assert(t, string(mdJSON) == `{"k1":["v1"]}`) + }) +} + +func Test_grpcMetadata_parseMetadata(t *testing.T) { + t.Run("no-metadata", func(t *testing.T) { + md := metadata.MD{"k1": []string{"v1"}} + msg := remote.NewMessage(nil, nil, nil, remote.Stream, remote.Server) + g := &grpcMetadata{} + err := g.parseMetadata(msg, &md) + test.Assert(t, err == nil) + test.Assert(t, md.Len() == 1) + }) + t.Run("with-metadata", func(t *testing.T) { + md := metadata.MD{} + msg := remote.NewMessage(nil, nil, nil, remote.Stream, remote.Server) + msg.TransInfo().TransStrInfo()[codec.StrKeyMetaData] = `{"k2":["v2"]}` + g := &grpcMetadata{} + err := g.parseMetadata(msg, &md) + test.Assert(t, err == nil) + test.Assert(t, md["k2"][0] == "v2") + }) + t.Run("with-invalid-metadata", func(t *testing.T) { + md := metadata.MD{"k1": []string{"v1"}} + msg := remote.NewMessage(nil, nil, nil, remote.Stream, remote.Server) + msg.TransInfo().TransStrInfo()[codec.StrKeyMetaData] = `invalid` + g := &grpcMetadata{} + err := g.parseMetadata(msg, &md) + test.Assert(t, err != nil) + test.Assert(t, md.Len() == 1) + }) +} + +func Test_grpcMetadata_parseHeader(t *testing.T) { + msg := remote.NewMessage(nil, nil, nil, remote.Stream, remote.Server) + msg.TransInfo().TransStrInfo()[codec.StrKeyMetaData] = `{"k1":["v1"]}` + g := &grpcMetadata{} + err := g.parseHeader(msg) + test.Assert(t, err == nil) + test.Assert(t, g.headersReceived["k1"][0] == "v1") +} + +func Test_grpcMetadata_parseTrailer(t *testing.T) { + msg := remote.NewMessage(nil, nil, nil, remote.Stream, remote.Server) + msg.TransInfo().TransStrInfo()[codec.StrKeyMetaData] = `{"k1":["v1"]}` + g := &grpcMetadata{} + err := g.parseTrailer(msg) + test.Assert(t, err == nil) + test.Assert(t, g.trailerReceived["k1"][0] == "v1") +} + +func Test_grpcMetadata_loadHeadersToSend(t *testing.T) { + t.Run("no-metadata-in-ctx", func(t *testing.T) { + g := &grpcMetadata{} + g.loadHeadersToSend(context.Background()) + test.Assert(t, len(g.headersToSend) == 0) + }) + + t.Run("with-metadata-in-ctx", func(t *testing.T) { + g := &grpcMetadata{} + ctx := metadata.AppendToOutgoingContext(context.Background(), "k1", "v1") + g.loadHeadersToSend(ctx) + test.Assert(t, len(g.headersToSend) == 1) + test.Assert(t, g.headersToSend["k1"][0] == "v1") + }) +} diff --git a/pkg/remote/trans/ttheaderstreaming/mock_test.go b/pkg/remote/trans/ttheaderstreaming/mock_test.go new file mode 100644 index 0000000000..9ebf773b0e --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/mock_test.go @@ -0,0 +1,352 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import ( + "context" + "io" + "net" + "time" + + "github.com/cloudwego/netpoll" + + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec/ttheader" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/streaming" +) + +type mockFrameMetaHandler struct { + remote.FrameMetaHandler + remote.StreamingMetaHandler + readMeta func(ctx context.Context, msg remote.Message) error + clientReadHeader func(ctx context.Context, msg remote.Message) error + serverReadHeader func(ctx context.Context, msg remote.Message) (context.Context, error) + writeHeader func(ctx context.Context, message remote.Message) error + readTrailer func(ctx context.Context, msg remote.Message) error + writeTrailer func(ctx context.Context, message remote.Message) error + onConnectStream func(ctx context.Context) (context.Context, error) +} + +func (m *mockFrameMetaHandler) ReadMeta(ctx context.Context, msg remote.Message) error { + if m.readMeta != nil { + return m.readMeta(ctx, msg) + } + return nil +} + +func (m *mockFrameMetaHandler) ClientReadHeader(ctx context.Context, msg remote.Message) error { + if m.clientReadHeader != nil { + return m.clientReadHeader(ctx, msg) + } + return nil +} + +func (m *mockFrameMetaHandler) ServerReadHeader(ctx context.Context, msg remote.Message) (context.Context, error) { + if m.serverReadHeader != nil { + return m.serverReadHeader(ctx, msg) + } + return ctx, nil +} + +func (m *mockFrameMetaHandler) WriteHeader(ctx context.Context, message remote.Message) error { + if m.writeHeader != nil { + return m.writeHeader(ctx, message) + } + return nil +} + +func (m *mockFrameMetaHandler) ReadTrailer(ctx context.Context, message remote.Message) error { + if m.readTrailer != nil { + return m.readTrailer(ctx, message) + } + return nil +} + +func (m *mockFrameMetaHandler) WriteTrailer(ctx context.Context, message remote.Message) error { + if m.writeTrailer != nil { + return m.writeTrailer(ctx, message) + } + return nil +} + +func (m *mockFrameMetaHandler) OnConnectStream(ctx context.Context) (context.Context, error) { + return m.onConnectStream(ctx) +} + +var _ netpollReader = (*mockConn)(nil) + +type mockConn struct { + readBuf []byte + readPos int + read func(b []byte) (int, error) + write func(b []byte) (int, error) + remoteAddr net.Addr + localAddr net.Addr + closeErr error +} + +func (m *mockConn) Reader() netpoll.Reader { + return m +} + +func (m *mockConn) Read(b []byte) (n int, err error) { + if m.read != nil { + return m.read(b) + } + available := len(m.readBuf) - m.readPos + if available <= 0 { + return 0, io.EOF + } + savedPos := m.readPos + m.readPos += len(b) + if m.readPos-savedPos >= available { + m.readPos = len(m.readBuf) + } + return copy(b, m.readBuf[savedPos:m.readPos]), nil +} + +func (m *mockConn) Write(b []byte) (n int, err error) { + if m.write != nil { + return m.write(b) + } + return len(b), nil // fake success +} + +func (m *mockConn) Close() error { + return m.closeErr +} + +func (m *mockConn) LocalAddr() net.Addr { + return m.localAddr +} + +func (m *mockConn) RemoteAddr() net.Addr { + return m.remoteAddr +} + +func (m *mockConn) SetDeadline(t time.Time) error { + return nil +} + +func (m *mockConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (m *mockConn) SetWriteDeadline(t time.Time) error { + return nil +} + +// Next implements netpoll.Reader +func (m *mockConn) Next(n int) (p []byte, err error) { + // TODO implement me + panic("implement me") +} + +// Peek implements netpoll.Reader +func (m *mockConn) Peek(n int) (buf []byte, err error) { + available := len(m.readBuf) - m.readPos + if available < n { + return nil, io.EOF + } + return m.readBuf[m.readPos : m.readPos+n], nil +} + +// Skip implements netpoll.Reader +func (m *mockConn) Skip(n int) (err error) { + // TODO implement me + panic("implement me") +} + +// Until implements netpoll.Reader +func (m *mockConn) Until(delim byte) (line []byte, err error) { + // TODO implement me + panic("implement me") +} + +// ReadString implements netpoll.Reader +func (m *mockConn) ReadString(n int) (s string, err error) { + // TODO implement me + panic("implement me") +} + +// ReadBinary implements netpoll.Reader +func (m *mockConn) ReadBinary(n int) (p []byte, err error) { + // TODO implement me + panic("implement me") +} + +// ReadByte implements netpoll.Reader +func (m *mockConn) ReadByte() (b byte, err error) { + // TODO implement me + panic("implement me") +} + +// Slice implements netpoll.Reader +func (m *mockConn) Slice(n int) (r netpoll.Reader, err error) { + // TODO implement me + panic("implement me") +} + +// Release implements netpoll.Reader +func (m *mockConn) Release() (err error) { + // TODO implement me + panic("implement me") +} + +// Len implements netpoll.Reader +func (m *mockConn) Len() (length int) { + // TODO implement me + panic("implement me") +} + +type mockByteBuffer struct { + remote.ByteBuffer + conn net.Conn + bufList [][]byte +} + +func (m *mockByteBuffer) Malloc(n int) ([]byte, error) { + buf := make([]byte, n) + m.bufList = append(m.bufList, buf) + return buf, nil +} + +func (m *mockByteBuffer) MallocLen() (total int) { + for _, buf := range m.bufList { + total += len(buf) + } + return +} + +func (m *mockByteBuffer) Flush() (err error) { + fullBuffer := make([]byte, 0, m.MallocLen()) + for _, buf := range m.bufList { + fullBuffer = append(fullBuffer, buf...) + } + _, err = m.conn.Write(fullBuffer) + return +} + +func (m *mockByteBuffer) Next(n int) (p []byte, err error) { + p = make([]byte, n) + for idx := 0; idx < n; { + c, err := m.conn.Read(p[idx:]) + if err != nil { + return nil, err + } + idx += c + } + return p, nil +} + +func (m *mockByteBuffer) Skip(n int) (err error) { + _, err = m.Next(n) + return +} + +func (m *mockByteBuffer) WriteString(s string) (n int, err error) { + m.bufList = append(m.bufList, []byte(s)) + return len(s), nil +} + +func (m *mockByteBuffer) WriteBinary(p []byte) (n int, err error) { + m.bufList = append(m.bufList, p) + return len(p), nil +} + +type mockExtension struct{} + +func (m *mockExtension) SetReadTimeout(ctx context.Context, conn net.Conn, cfg rpcinfo.RPCConfig, role remote.RPCRole) { +} + +func (m *mockExtension) NewWriteByteBuffer(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer { + return &mockByteBuffer{conn: conn} +} + +func (m *mockExtension) NewReadByteBuffer(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer { + return &mockByteBuffer{conn: conn} +} + +func (m *mockExtension) ReleaseBuffer(buffer remote.ByteBuffer, err error) error { + // TODO implement me + panic("implement me") +} + +func (m *mockExtension) IsTimeoutErr(err error) bool { + // TODO implement me + panic("implement me") +} + +func (m *mockExtension) IsRemoteClosedErr(err error) bool { + // TODO implement me + panic("implement me") +} + +type mockStream struct { + streaming.Stream + doCloseSend func(error) error + doRecvMsg func(msg interface{}) error + doTrailer func() metadata.MD +} + +func (m *mockStream) closeSend(err error) error { + return m.doCloseSend(err) +} + +func (m *mockStream) RecvMsg(msg interface{}) error { + return m.doRecvMsg(msg) +} + +func (m *mockStream) Trailer() metadata.MD { + if m.doTrailer != nil { + return m.doTrailer() + } + return nil +} + +func decodeMessage(ctx context.Context, buf []byte, data interface{}, role remote.RPCRole) (remote.Message, error) { + return decodeMessageFromBuffer(ctx, remote.NewReaderBuffer(buf), data, role) +} + +func decodeMessageFromBuffer(ctx context.Context, in remote.ByteBuffer, data interface{}, role remote.RPCRole) (remote.Message, error) { + ri := rpcinfo.NewRPCInfo( + rpcinfo.NewEndpointInfo("", "", nil, nil), + rpcinfo.NewMutableEndpointInfo("", "", nil, nil).ImmutableView(), + rpcinfo.NewInvocation("", ""), + rpcinfo.NewRPCConfig(), + rpcinfo.NewRPCStats(), + ) + msg := remote.NewMessage( + data, + nil, + ri, + remote.Stream, + role, + ) + err := ttheader.NewStreamCodec().Decode(ctx, msg, in) + return msg, err +} + +func encodeMessage(ctx context.Context, msg remote.Message) ([]byte, error) { + out := remote.NewWriterBuffer(256) + if err := ttheader.NewStreamCodec().Encode(ctx, msg, out); err != nil { + return nil, err + } + return out.Bytes() +} diff --git a/pkg/remote/trans/ttheaderstreaming/server_handler.go b/pkg/remote/trans/ttheaderstreaming/server_handler.go new file mode 100644 index 0000000000..f2c7fbd108 --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/server_handler.go @@ -0,0 +1,289 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import ( + "context" + "errors" + "fmt" + "math" + "net" + "runtime/debug" + "time" + + "github.com/cloudwego/localsession/backup" + "github.com/cloudwego/netpoll" + + "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/remote/codec/ttheader" + "github.com/cloudwego/kitex/pkg/remote/trans" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/stats" + "github.com/cloudwego/kitex/pkg/streaming" + "github.com/cloudwego/kitex/transport" +) + +const ( + detectionSize = 8 + infinityTimeout = time.Duration(math.MaxInt64) +) + +var ( + _ remote.ServerTransHandler = (*ttheaderStreamingTransHandler)(nil) + _ remote.InvokeHandleFuncSetter = (*ttheaderStreamingTransHandler)(nil) + + ErrNotTTHeaderStreaming = errors.New("not ttheader streaming") +) + +// NewSvrTransHandler returns a new server transHandler with given extension (netpoll, gonet, etc.) +func NewSvrTransHandler(opt *remote.ServerOption, ext trans.Extension) (remote.ServerTransHandler, error) { + return &ttheaderStreamingTransHandler{ + opt: opt, + codec: ttheader.NewStreamCodec(), + ext: ext, + frameMetaHandler: getServerFrameMetaHandler(opt.TTHeaderFrameMetaHandler), + }, nil +} + +func getServerFrameMetaHandler(handler remote.FrameMetaHandler) remote.FrameMetaHandler { + if handler != nil { + return handler + } + return NewServerTTHeaderFrameHandler() +} + +type ttheaderStreamingTransHandler struct { + opt *remote.ServerOption + codec remote.Codec + invokeHandler endpoint.Endpoint + ext trans.Extension + frameMetaHandler remote.FrameMetaHandler +} + +// ProtocolMatch implements the detection.DetectableServerTransHandler interface. +func (t *ttheaderStreamingTransHandler) ProtocolMatch(ctx context.Context, conn net.Conn) (err error) { + preface, err := getReader(conn).Peek(detectionSize) + if err != nil { + return err + } + if codec.IsTTHeaderStreaming(preface) { + return nil + } + return ErrNotTTHeaderStreaming +} + +// SetInvokeHandleFunc implements the remote.InvokeHandleFuncSetter interface. +func (t *ttheaderStreamingTransHandler) SetInvokeHandleFunc(invokeHandler endpoint.Endpoint) { + t.invokeHandler = invokeHandler +} + +func (t *ttheaderStreamingTransHandler) OnError(ctx context.Context, err error, conn net.Conn) { + var detailedError *kerrors.DetailedError + if ok := errors.As(err, &detailedError); ok && detailedError.Stack() != "" { + klog.CtxErrorf(ctx, "KITEX: processing ttheader streaming request error, remoteAddr=%s, error=%s\nstack=%s", + getRemoteAddr(ctx, conn), err.Error(), detailedError.Stack()) + } else { + klog.CtxErrorf(ctx, "KITEX: processing ttheader streaming request error, remoteAddr=%s, error=%s", + getRemoteAddr(ctx, conn), err.Error()) + } +} + +// Write should not be called in ttheaderStreamingTransHandler +func (t *ttheaderStreamingTransHandler) Write(ctx context.Context, conn net.Conn, send remote.Message) (context.Context, error) { + panic("not used") +} + +// Read should not be called in ttheaderStreamingTransHandler +func (t *ttheaderStreamingTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Message) (context.Context, error) { + panic("not used") +} + +// OnMessage should not be called in ttheaderStreamingTransHandler +func (t *ttheaderStreamingTransHandler) OnMessage(ctx context.Context, args, result remote.Message) (context.Context, error) { + panic("not used") +} + +// SetPipeline is not used in ttheaderStreamingTransHandler +func (t *ttheaderStreamingTransHandler) SetPipeline(pipeline *remote.TransPipeline) { + // not used +} + +// OnActive sets the read timeout on the connection to infinity +func (t *ttheaderStreamingTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) { + // set readTimeout to infinity to avoid streaming break + // use keepalive to check the health of connection + if npConn, ok := conn.(netpoll.Connection); ok { + _ = npConn.SetReadTimeout(infinityTimeout) + } else { + _ = conn.SetReadDeadline(time.Now().Add(infinityTimeout)) + } + return ctx, nil +} + +// OnInactive is not used in ttheaderStreamingTransHandler +func (t *ttheaderStreamingTransHandler) OnInactive(ctx context.Context, conn net.Conn) {} + +func (t *ttheaderStreamingTransHandler) finishTracer(ctx context.Context, ri rpcinfo.RPCInfo, err error, panicErr interface{}) { + rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()) + if rpcStats == nil { + return + } + if panicErr != nil { + rpcStats.SetPanicked(panicErr) + } + t.opt.TracerCtl.DoFinish(ctx, ri, err) + rpcStats.Reset() +} + +// OnRead handles a new TTHeader Stream +func (t *ttheaderStreamingTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) { + ctx, ri := t.newCtxWithRPCInfo(ctx, conn) + rawStream := newTTHeaderStream(ctx, conn, ri, t.codec, remote.Server, t.ext, t.frameMetaHandler) + + if ctx, err = rawStream.serverReadHeaderFrame(t.opt); err != nil { + if errors.Is(err, ErrDirtyFrame) { + // There could be dirty frames from the previous stream while the connection has been closed by the client. + // Just ignore the error (and drop the frame) since we're not starting a new stream at this moment. + // klog.CtxWarnf(ctx, "KITEX: TTHeader Streaming serverReadHeaderFrame got dirty frame") + return nil + } + t.OnError(ctx, err, conn) + return fmt.Errorf("KITEX: TTHeader Streaming serverReadHeaderFrame failed, error=%v", err) + } + + // tracer is started in rawStream.serverReadHeaderFrame() + defer func() { + panicErr := recover() + err = t.onEndStream(ctx, conn, ri, panicErr, err, rawStream) + }() + + st := endpoint.NewStreamWithMiddleware(rawStream, t.opt.RecvEndpoint, t.opt.SendEndpoint) + // bind stream into ctx, for user to set grpc style metadata(headers/trailers) by provided api in meta_api.go + ctx = streaming.NewCtxWithStream(ctx, st) + + svcInfo, methodInfo := t.retrieveServiceMethodInfo(ri) + if methodInfo == nil { + err = t.invokeUnknownMethod(ctx, ri, st, svcInfo) + } else { + err = t.invokeHandler(ctx, &streaming.Args{Stream: st}, nil) + } + + if err != nil { + t.OnError(ctx, err, conn) + } + + // when error occurs, a non-nil error should be returned to close the connection; + // otherwise the connection will be reused for the next request which might contain unexpected data. + return err +} + +// interface makes it easier to test +type streamFinisher interface { + closeSend(err error) error +} + +func (t *ttheaderStreamingTransHandler) onEndStream( + ctx context.Context, conn net.Conn, ri rpcinfo.RPCInfo, panicInfo interface{}, err error, st streamFinisher, +) error { + if panicInfo != nil { + klog.CtxErrorf(ctx, + "KITEX: TTHeader Streaming panic happened in %s, remoteAddress=%v, error=%s\nstack=%s", + "ttheaderStreamingTransHandler.OnRead", getRemoteAddr(ctx, conn), panicInfo, string(debug.Stack()), + ) + } + t.finishTracer(ctx, ri, err, panicInfo) + if panicInfo != nil { + // make sure a non-nil err is returned to prevent reuse of the connection + err = fmt.Errorf("KITEX: TTHeader Streaming panic: %v", panicInfo) + } + + // always try to send a TrailerFrame after a new stream is accepted + if closeErr := st.closeSend(err); closeErr != nil { + t.OnError(ctx, closeErr, conn) + err = closeErr + } + + // if err == nil { + // // if no error, read until the TrailerFrame to before reusing the connection (avoid dirty frames) + // if readErr := rawStream.serverWaitTrailer(); readErr != nil { + // t.OnError(ctx, readErr, conn) + // err = readErr + // } + // } + // _ = rawStream.cleanup() + return err +} + +func (t *ttheaderStreamingTransHandler) newCtxWithRPCInfo(ctx context.Context, conn net.Conn) (context.Context, rpcinfo.RPCInfo) { + ri := t.opt.InitOrResetRPCInfoFunc(nil, conn.RemoteAddr()) + _ = rpcinfo.AsMutableRPCConfig(ri.Config()).SetTransportProtocol(transport.TTHeader) + return rpcinfo.NewCtxWithRPCInfo(ctx, ri), ri +} + +func (t *ttheaderStreamingTransHandler) invokeUnknownMethod( + ctx context.Context, ri rpcinfo.RPCInfo, st streaming.Stream, svcInfo *serviceinfo.ServiceInfo, +) (err error) { + if t.opt.GRPCUnknownServiceHandler != nil { + err = t.unknownMethodWithRecover(ctx, ri, st) + if err = coverBizStatusError(err, ri); err != nil { + err = kerrors.ErrBiz.WithCause(err) + } + return err + } + + if svcInfo == nil { + return remote.NewTransErrorWithMsg(remote.UnknownService, + fmt.Sprintf("unknown service %s", ri.Invocation().ServiceName())) + } + return remote.NewTransErrorWithMsg(remote.UnknownMethod, + fmt.Sprintf("unknown method %s", ri.Invocation().MethodName())) +} + +func (t *ttheaderStreamingTransHandler) unknownMethodWithRecover( + ctx context.Context, ri rpcinfo.RPCInfo, st streaming.Stream, +) (err error) { + defer func() { + rpcinfo.Record(ctx, ri, stats.ServerHandleStart, nil) + if panicInfo := recover(); panicInfo != nil { + err = kerrors.ErrPanic.WithCauseAndStack( + fmt.Errorf( + "[happened in biz handler, method=%s.%s, please check the panic at the server side] %s", + ri.Invocation().ServiceName(), ri.Invocation().MethodName(), panicInfo), + string(debug.Stack())) + rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()) + rpcStats.SetPanicked(panicInfo) + } + rpcinfo.Record(ctx, ri, stats.ServerHandleFinish, err) + // clear session + backup.ClearCtx() + }() + return t.opt.GRPCUnknownServiceHandler(ctx, ri.Invocation().MethodName(), st) +} + +func (t *ttheaderStreamingTransHandler) retrieveServiceMethodInfo(ri rpcinfo.RPCInfo) (*serviceinfo.ServiceInfo, serviceinfo.MethodInfo) { + svc, method := ri.Invocation().ServiceName(), ri.Invocation().MethodName() + svcInfo := t.opt.SvcSearchMap[remote.BuildMultiServiceKey(svc, method)] + if svcInfo == nil { + return nil, nil + } + return svcInfo, svcInfo.MethodInfo(method) +} diff --git a/pkg/remote/trans/ttheaderstreaming/server_handler_test.go b/pkg/remote/trans/ttheaderstreaming/server_handler_test.go new file mode 100644 index 0000000000..29ce3e9c37 --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/server_handler_test.go @@ -0,0 +1,748 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import ( + "context" + "errors" + "net" + "reflect" + "strings" + "testing" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/remote/codec/ttheader" + "github.com/cloudwego/kitex/pkg/remote/transmeta" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/stats" + "github.com/cloudwego/kitex/pkg/streaming" + "github.com/cloudwego/kitex/transport" +) + +func Test_ttheaderStreamingTransHandler_ProtocolMatch(t *testing.T) { + t.Run("peek-err", func(t *testing.T) { + opt := &remote.ServerOption{} + h, _ := NewSvrTransHandler(opt, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + err := svrHandler.ProtocolMatch(context.Background(), &mockConn{ + readBuf: make([]byte, 0), + }) + test.Assert(t, err != nil, err) + }) + t.Run("match", func(t *testing.T) { + opt := &remote.ServerOption{} + h, _ := NewSvrTransHandler(opt, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + buf := []byte{0, 0, 0, 0, 0x10, 0x00, 0x00, 0x02} + err := svrHandler.ProtocolMatch(context.Background(), &mockConn{ + readBuf: buf, + }) + test.Assert(t, err == nil, err) + }) + t.Run("unmatched", func(t *testing.T) { + opt := &remote.ServerOption{} + h, _ := NewSvrTransHandler(opt, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + buf := []byte{0, 0, 0, 0, 0x10, 0x00, 0x00, 0x00} + err := svrHandler.ProtocolMatch(context.Background(), &mockConn{ + readBuf: buf, + }) + test.Assert(t, errors.Is(err, ErrNotTTHeaderStreaming), err) + }) +} + +func Test_getServerMetaHandler(t *testing.T) { + t.Run("nil", func(t *testing.T) { + h := getServerFrameMetaHandler(nil) + _, ok := h.(*serverTTHeaderFrameHandler) + test.Assert(t, ok, h) + }) + t.Run("non-nil", func(t *testing.T) { + h := getServerFrameMetaHandler(&mockFrameMetaHandler{}) + _, ok := h.(*mockFrameMetaHandler) + test.Assert(t, ok, h) + }) +} + +type mockTracer struct { + start func(ctx context.Context) context.Context + finish func(ctx context.Context) +} + +func (m *mockTracer) Start(ctx context.Context) context.Context { + if m.start != nil { + return m.start(ctx) + } + return ctx +} + +func (m *mockTracer) Finish(ctx context.Context) { + if m.finish != nil { + m.finish(ctx) + } +} + +func Test_ttheaderStreamingTransHandler_finishTracer(t *testing.T) { + t.Run("normal", func(t *testing.T) { + called := false + ctrl := &rpcinfo.TraceController{} + ctrl.Append(&mockTracer{ + finish: func(ctx context.Context) { + ri := rpcinfo.GetRPCInfo(ctx) + event := ri.Stats().GetEvent(stats.RPCFinish) + test.Assert(t, event != nil && event.Event() == stats.RPCFinish, event) + called = true + }, + }) + h, _ := NewSvrTransHandler(&remote.ServerOption{ + TracerCtl: ctrl, + }, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + stat := rpcinfo.NewRPCStats() + rpcinfo.AsMutableRPCStats(stat).SetLevel(stats.LevelBase) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, stat) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + svrHandler.finishTracer(ctx, ri, nil, nil) + test.Assert(t, called, called) + }) + t.Run("err", func(t *testing.T) { + err := errors.New("err") + called := false + ctrl := &rpcinfo.TraceController{} + ctrl.Append(&mockTracer{ + finish: func(ctx context.Context) { + ri := rpcinfo.GetRPCInfo(ctx) + event := ri.Stats().GetEvent(stats.RPCFinish) + test.Assert(t, event != nil && event.Event() == stats.RPCFinish, event) + test.Assert(t, event.Info() == err.Error(), event) + called = true + }, + }) + h, _ := NewSvrTransHandler(&remote.ServerOption{ + TracerCtl: ctrl, + }, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + stat := rpcinfo.NewRPCStats() + rpcinfo.AsMutableRPCStats(stat).SetLevel(stats.LevelBase) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, stat) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + svrHandler.finishTracer(ctx, ri, err, nil) + test.Assert(t, called, called) + }) + t.Run("panic", func(t *testing.T) { + called := false + ctrl := &rpcinfo.TraceController{} + ctrl.Append(&mockTracer{ + finish: func(ctx context.Context) { + ri := rpcinfo.GetRPCInfo(ctx) + event := ri.Stats().GetEvent(stats.RPCFinish) + test.Assert(t, event != nil && event.Event() == stats.RPCFinish, event) + panicked, info := ri.Stats().Panicked() + test.Assert(t, panicked && info == "panic", panicked, info) + called = true + }, + }) + h, _ := NewSvrTransHandler(&remote.ServerOption{ + TracerCtl: ctrl, + }, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + stat := rpcinfo.NewRPCStats() + rpcinfo.AsMutableRPCStats(stat).SetLevel(stats.LevelBase) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, stat) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + svrHandler.finishTracer(ctx, ri, nil, "panic") + test.Assert(t, called, called) + }) + + t.Run("no-rpcStats", func(t *testing.T) { + called := false + ctrl := &rpcinfo.TraceController{} + ctrl.Append(&mockTracer{ + finish: func(ctx context.Context) { + called = true + }, + }) + h, _ := NewSvrTransHandler(&remote.ServerOption{ + TracerCtl: ctrl, + }, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, nil) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + svrHandler.finishTracer(ctx, ri, nil, nil) + test.Assert(t, !called, called) + }) +} + +func Test_ttheaderStreamingTransHandler_retrieveServiceMethodInfo(t *testing.T) { + h, _ := NewSvrTransHandler(&remote.ServerOption{ + SvcSearchMap: map[string]*serviceinfo.ServiceInfo{ + remote.BuildMultiServiceKey("service", "method"): { + ServiceName: "service", + Methods: map[string]serviceinfo.MethodInfo{ + "method": serviceinfo.NewMethodInfo(nil, nil, nil, false), + }, + }, + }, + }, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + ivk := rpcinfo.NewInvocation("service", "method") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, nil) + svcInfo, methodInfo := svrHandler.retrieveServiceMethodInfo(ri) + test.Assert(t, svcInfo != nil, svcInfo) + test.Assert(t, methodInfo != nil, methodInfo) +} + +func Test_ttheaderStreamingTransHandler_unknownMethodWithRecover(t *testing.T) { + t.Run("normal", func(t *testing.T) { + called := false + opt := &remote.ServerOption{ + GRPCUnknownServiceHandler: func(ctx context.Context, method string, stream streaming.Stream) error { + test.Assert(t, method == "method", method) + called = true + return nil + }, + } + h, _ := NewSvrTransHandler(opt, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + ivk := rpcinfo.NewInvocation("service", "method") + stat := rpcinfo.NewRPCStats() + rpcinfo.AsMutableRPCStats(stat).SetLevel(stats.LevelDetailed) + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, stat) + err := svrHandler.unknownMethodWithRecover(context.Background(), ri, nil) + test.Assert(t, err == nil, err) + test.Assert(t, called, called) + test.Assert(t, stat.GetEvent(stats.ServerHandleStart) != nil, stat) + event := stat.GetEvent(stats.ServerHandleFinish) + test.Assert(t, event.Status() == stats.StatusInfo, event) + panicked, _ := stat.Panicked() + test.Assert(t, !panicked, panicked) + }) + t.Run("panic", func(t *testing.T) { + called := false + opt := &remote.ServerOption{ + GRPCUnknownServiceHandler: func(ctx context.Context, method string, stream streaming.Stream) error { + called = true + panic("panic") + }, + } + h, _ := NewSvrTransHandler(opt, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + ivk := rpcinfo.NewInvocation("service", "method") + stat := rpcinfo.NewRPCStats() + rpcinfo.AsMutableRPCStats(stat).SetLevel(stats.LevelDetailed) + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, stat) + err := svrHandler.unknownMethodWithRecover(context.Background(), ri, nil) + test.Assert(t, err != nil, err) + test.Assert(t, called, called) + event := stat.GetEvent(stats.ServerHandleFinish) + test.Assert(t, event.Status() == stats.StatusError, event) + panicked, panicInfo := stat.Panicked() + test.Assert(t, panicked && panicInfo == "panic", panicked, panicInfo) + }) +} + +func Test_ttheaderStreamingTransHandler_invokeUnknownMethod(t *testing.T) { + t.Run("unknown-method:no-err", func(t *testing.T) { + opt := &remote.ServerOption{ + GRPCUnknownServiceHandler: func(ctx context.Context, method string, stream streaming.Stream) error { + return nil + }, + } + h, _ := NewSvrTransHandler(opt, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + ivk := rpcinfo.NewInvocation("service", "method") + stat := rpcinfo.NewRPCStats() + rpcinfo.AsMutableRPCStats(stat).SetLevel(stats.LevelDetailed) + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, stat) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + err := svrHandler.invokeUnknownMethod(ctx, ri, nil, nil) + + test.Assert(t, err == nil, err) + }) + t.Run("unknown-method:normal-err", func(t *testing.T) { + handlerErr := errors.New("error") + opt := &remote.ServerOption{ + GRPCUnknownServiceHandler: func(ctx context.Context, method string, stream streaming.Stream) error { + return handlerErr + }, + } + h, _ := NewSvrTransHandler(opt, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + ivk := rpcinfo.NewInvocation("service", "method") + stat := rpcinfo.NewRPCStats() + rpcinfo.AsMutableRPCStats(stat).SetLevel(stats.LevelDetailed) + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, stat) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + err := svrHandler.invokeUnknownMethod(ctx, ri, nil, nil) + + var detailedErr *kerrors.DetailedError + test.Assert(t, errors.As(err, &detailedErr), detailedErr) + test.Assert(t, detailedErr.Is(kerrors.ErrBiz)) + test.Assert(t, errors.Is(handlerErr, detailedErr.Unwrap())) + }) + + t.Run("unknown-method:biz-status-err", func(t *testing.T) { + handlerErr := kerrors.NewBizStatusError(100, "error") + opt := &remote.ServerOption{ + GRPCUnknownServiceHandler: func(ctx context.Context, method string, stream streaming.Stream) error { + return handlerErr + }, + } + h, _ := NewSvrTransHandler(opt, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + ivk := rpcinfo.NewInvocation("service", "method") + stat := rpcinfo.NewRPCStats() + rpcinfo.AsMutableRPCStats(stat).SetLevel(stats.LevelDetailed) + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, stat) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + err := svrHandler.invokeUnknownMethod(ctx, ri, nil, nil) + + test.Assert(t, err == nil, err) + bizErr := ri.Invocation().BizStatusErr() + test.Assert(t, bizErr != nil, bizErr) + test.Assert(t, bizErr.BizStatusCode() == handlerErr.BizStatusCode()) + test.Assert(t, bizErr.BizMessage() == handlerErr.BizMessage()) + }) + + t.Run("no-unknown-method,no-svc-info", func(t *testing.T) { + opt := &remote.ServerOption{ + GRPCUnknownServiceHandler: nil, + } + h, _ := NewSvrTransHandler(opt, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + ivk := rpcinfo.NewInvocation("service", "method") + stat := rpcinfo.NewRPCStats() + rpcinfo.AsMutableRPCStats(stat).SetLevel(stats.LevelDetailed) + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, stat) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + err := svrHandler.invokeUnknownMethod(ctx, ri, nil, nil) + + test.Assert(t, err != nil, err) + var transErr *remote.TransError + errors.As(err, &transErr) + test.Assert(t, transErr.TypeID() == remote.UnknownService, transErr.TypeID()) + }) + + t.Run("no-unknown-method,has-svc-info", func(t *testing.T) { + opt := &remote.ServerOption{ + GRPCUnknownServiceHandler: nil, + } + h, _ := NewSvrTransHandler(opt, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + ivk := rpcinfo.NewInvocation("service", "method") + stat := rpcinfo.NewRPCStats() + rpcinfo.AsMutableRPCStats(stat).SetLevel(stats.LevelDetailed) + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, stat) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + err := svrHandler.invokeUnknownMethod(ctx, ri, nil, &serviceinfo.ServiceInfo{}) + + test.Assert(t, err != nil, err) + var transErr *remote.TransError + errors.As(err, &transErr) + test.Assert(t, transErr.TypeID() == remote.UnknownMethod, transErr.TypeID()) + }) +} + +func Test_ttheaderStreamingTransHandler_newCtxWithRPCInfo(t *testing.T) { + expectedAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8888") + opt := &remote.ServerOption{ + InitOrResetRPCInfoFunc: func(info rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { + test.Assert(t, reflect.DeepEqual(addr, expectedAddr), addr) + from := rpcinfo.NewEndpointInfo("from-service", "from-method", addr, nil) + cfg := rpcinfo.NewRPCConfig() + ri := rpcinfo.NewRPCInfo(from, nil, nil, cfg, nil) + return ri + }, + } + h, _ := NewSvrTransHandler(opt, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + conn := &mockConn{ + remoteAddr: expectedAddr, + } + ctx, ri := svrHandler.newCtxWithRPCInfo(context.Background(), conn) + test.Assert(t, rpcinfo.GetRPCInfo(ctx) != nil) + test.Assert(t, ri.Config().TransportProtocol() == transport.TTHeader, ri) +} + +func Test_ttheaderStreamingTransHandler_onEndStream(t *testing.T) { + addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8888") + conn := &mockConn{ + remoteAddr: addr, + } + + t.Run("normal", func(t *testing.T) { + ctrl := &rpcinfo.TraceController{} + h, _ := NewSvrTransHandler(&remote.ServerOption{ + TracerCtl: ctrl, + }, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + stat := rpcinfo.NewRPCStats() + rpcinfo.AsMutableRPCStats(stat).SetLevel(stats.LevelBase) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, stat) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + gotErr := errors.New("got error") + st := &mockStream{ + doCloseSend: func(err error) error { + gotErr = err + return nil + }, + } + err := svrHandler.onEndStream(ctx, conn, ri, nil, nil, st) + test.Assert(t, err == nil, err) + test.Assert(t, gotErr == nil, gotErr) + }) + + t.Run("err", func(t *testing.T) { + ctrl := &rpcinfo.TraceController{} + h, _ := NewSvrTransHandler(&remote.ServerOption{ + TracerCtl: ctrl, + }, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + stat := rpcinfo.NewRPCStats() + rpcinfo.AsMutableRPCStats(stat).SetLevel(stats.LevelBase) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, stat) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + invokeErr := errors.New("invokeErr") + gotErr := errors.New("got error") + st := &mockStream{ + doCloseSend: func(err error) error { + gotErr = err + return nil + }, + } + err := svrHandler.onEndStream(ctx, conn, ri, nil, invokeErr, st) + test.Assert(t, errors.Is(err, invokeErr), err) + test.Assert(t, errors.Is(gotErr, invokeErr), gotErr) + }) + + t.Run("panic", func(t *testing.T) { + ctrl := &rpcinfo.TraceController{} + h, _ := NewSvrTransHandler(&remote.ServerOption{ + TracerCtl: ctrl, + }, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + stat := rpcinfo.NewRPCStats() + rpcinfo.AsMutableRPCStats(stat).SetLevel(stats.LevelBase) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, stat) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + gotErr := errors.New("got error") + st := &mockStream{ + doCloseSend: func(err error) error { + gotErr = err + return nil + }, + } + err := svrHandler.onEndStream(ctx, conn, ri, "panic", nil, st) + test.Assert(t, strings.Contains(err.Error(), "panic"), err) + test.Assert(t, strings.Contains(gotErr.Error(), "panic"), gotErr) + }) + + t.Run("normal+closeSendErr", func(t *testing.T) { + ctrl := &rpcinfo.TraceController{} + h, _ := NewSvrTransHandler(&remote.ServerOption{ + TracerCtl: ctrl, + }, nil) + svrHandler := h.(*ttheaderStreamingTransHandler) + stat := rpcinfo.NewRPCStats() + rpcinfo.AsMutableRPCStats(stat).SetLevel(stats.LevelBase) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, stat) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + closeSendErr := errors.New("closeSendErr") + st := &mockStream{ + doCloseSend: func(err error) error { + return closeSendErr + }, + } + err := svrHandler.onEndStream(ctx, conn, ri, nil, nil, st) + test.Assert(t, errors.Is(err, closeSendErr), err) + }) +} + +func Test_ttheaderStreamingTransHandler_OnRead(t *testing.T) { + idlServiceName := "idl-service" + toMethodName := "to-method" + addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8888") + svcInfo := &serviceinfo.ServiceInfo{ + ServiceName: "xxx", + PayloadCodec: serviceinfo.Protobuf, + Methods: map[string]serviceinfo.MethodInfo{ + toMethodName: serviceinfo.NewMethodInfo( + nil, nil, nil, false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), + ), + }, + } + + t.Run("normal", func(t *testing.T) { + opt := &remote.ServerOption{ + Option: remote.Option{}, + InitOrResetRPCInfoFunc: func(info rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { + from := rpcinfo.NewEndpointInfo("from-service", "from-method", addr, nil) + to := rpcinfo.NewMutableEndpointInfo("to-service", toMethodName, addr, nil) + ivk := rpcinfo.NewInvocation(idlServiceName, toMethodName) + return rpcinfo.NewRPCInfo(from, to.ImmutableView(), ivk, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + }, + TracerCtl: &rpcinfo.TraceController{}, + SvcSearchMap: map[string]*serviceinfo.ServiceInfo{ + remote.BuildMultiServiceKey(idlServiceName, toMethodName): svcInfo, + }, + } + ext := &mockExtension{} + h, err := NewSvrTransHandler(opt, ext) + test.Assert(t, err == nil, err) + svrHandler := h.(*ttheaderStreamingTransHandler) + invokeHandlerCalled := false + svrHandler.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { + invokeHandlerCalled = true + return nil + }) + + writeCnt := int32(0) + + conn := &mockConn{ + readBuf: func() []byte { + // create a client HeaderFrame for read + ivk := rpcinfo.NewInvocation(idlServiceName, toMethodName) + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + msg := remote.NewMessage(nil, svcInfo, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + msg.TransInfo().TransIntInfo()[transmeta.ToMethod] = ivk.MethodName() + msg.TransInfo().TransStrInfo()[transmeta.HeaderIDLServiceName] = ivk.ServiceName() + out := remote.NewWriterBuffer(1024) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + if err = ttheader.NewStreamCodec().Encode(ctx, msg, out); err != nil { + panic(err) + } + buf, _ := out.Bytes() + return buf + }(), + write: func(b []byte) (int, error) { + ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, nil) + msg := remote.NewMessage(nil, svcInfo, ri, remote.Stream, remote.Server) + in := remote.NewReaderBuffer(b) + if err := ttheader.NewStreamCodec().Decode(context.Background(), msg, in); err != nil { + panic(err) + } + if writeCnt == 0 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader) + } else if writeCnt == 1 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeTrailer) + } else { + test.Assert(t, false, "writeCnt should be 0 or 1") + } + writeCnt += 1 + return len(b), nil + }, + remoteAddr: addr, + } + + ctx := context.Background() + err = svrHandler.OnRead(ctx, conn) + test.Assert(t, err == nil, err) + test.Assert(t, invokeHandlerCalled, invokeHandlerCalled) + test.Assert(t, writeCnt == 2, writeCnt) + }) + + t.Run("dirty-frame:ignore", func(t *testing.T) { + opt := &remote.ServerOption{ + Option: remote.Option{}, + InitOrResetRPCInfoFunc: func(info rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { + from := rpcinfo.NewEndpointInfo("from-service", "from-method", addr, nil) + to := rpcinfo.NewMutableEndpointInfo("to-service", toMethodName, addr, nil) + ivk := rpcinfo.NewInvocation(idlServiceName, toMethodName) + return rpcinfo.NewRPCInfo(from, to.ImmutableView(), ivk, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + }, + TracerCtl: &rpcinfo.TraceController{}, + SvcSearchMap: map[string]*serviceinfo.ServiceInfo{ + remote.BuildMultiServiceKey(idlServiceName, toMethodName): svcInfo, + }, + } + ext := &mockExtension{} + h, err := NewSvrTransHandler(opt, ext) + test.Assert(t, err == nil, err) + svrHandler := h.(*ttheaderStreamingTransHandler) + invokeHandlerCalled := false + svrHandler.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { + invokeHandlerCalled = true + return nil + }) + + conn := &mockConn{ + readBuf: func() []byte { + // create a client HeaderFrame for read + ivk := rpcinfo.NewInvocation(idlServiceName, toMethodName) + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + msg := remote.NewMessage(nil, svcInfo, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeTrailer // Dirty Frame + msg.TransInfo().TransIntInfo()[transmeta.ToMethod] = ivk.MethodName() + msg.TransInfo().TransStrInfo()[transmeta.HeaderIDLServiceName] = ivk.ServiceName() + out := remote.NewWriterBuffer(1024) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + if err = ttheader.NewStreamCodec().Encode(ctx, msg, out); err != nil { + panic(err) + } + buf, _ := out.Bytes() + return buf + }(), + remoteAddr: addr, + } + + ctx := context.Background() + err = svrHandler.OnRead(ctx, conn) + test.Assert(t, err == nil, err) + test.Assert(t, !invokeHandlerCalled, invokeHandlerCalled) + }) + + t.Run("read-header-failed", func(t *testing.T) { + opt := &remote.ServerOption{ + Option: remote.Option{}, + InitOrResetRPCInfoFunc: func(info rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { + from := rpcinfo.NewEndpointInfo("from-service", "from-method", addr, nil) + to := rpcinfo.NewMutableEndpointInfo("to-service", toMethodName, addr, nil) + ivk := rpcinfo.NewInvocation(idlServiceName, toMethodName) + return rpcinfo.NewRPCInfo(from, to.ImmutableView(), ivk, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + }, + TracerCtl: &rpcinfo.TraceController{}, + SvcSearchMap: map[string]*serviceinfo.ServiceInfo{ + remote.BuildMultiServiceKey(idlServiceName, toMethodName): svcInfo, + }, + } + ext := &mockExtension{} + h, err := NewSvrTransHandler(opt, ext) + test.Assert(t, err == nil, err) + svrHandler := h.(*ttheaderStreamingTransHandler) + invokeHandlerCalled := false + svrHandler.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { + invokeHandlerCalled = true + return nil + }) + + conn := &mockConn{ + remoteAddr: addr, + readBuf: func() []byte { + // create a client HeaderFrame for read + ivk := rpcinfo.NewInvocation(idlServiceName, toMethodName) + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + msg := remote.NewMessage(nil, svcInfo, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + msg.TransInfo().TransIntInfo()[transmeta.ToMethod] = ivk.MethodName() + msg.TransInfo().TransStrInfo()[transmeta.HeaderIDLServiceName] = ivk.ServiceName() + out := remote.NewWriterBuffer(1024) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + if err = ttheader.NewStreamCodec().Encode(ctx, msg, out); err != nil { + panic(err) + } + buf, _ := out.Bytes() + return buf[0 : len(buf)-1] // short read + }(), + } + + ctx := context.Background() + err = svrHandler.OnRead(ctx, conn) + test.Assert(t, strings.HasSuffix(err.Error(), "error=EOF"), err.Error()) + test.Assert(t, !invokeHandlerCalled, invokeHandlerCalled) + }) + + t.Run("unknown-method", func(t *testing.T) { + unknownMethodCalled := false + opt := &remote.ServerOption{ + Option: remote.Option{}, + InitOrResetRPCInfoFunc: func(info rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { + from := rpcinfo.NewEndpointInfo("from-service", "from-method", addr, nil) + to := rpcinfo.NewMutableEndpointInfo("to-service", toMethodName, addr, nil) + ivk := rpcinfo.NewInvocation(idlServiceName, toMethodName) + return rpcinfo.NewRPCInfo(from, to.ImmutableView(), ivk, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + }, + TracerCtl: &rpcinfo.TraceController{}, + SvcSearchMap: map[string]*serviceinfo.ServiceInfo{ + remote.BuildMultiServiceKey(idlServiceName, toMethodName): svcInfo, + }, + GRPCUnknownServiceHandler: func(ctx context.Context, method string, stream streaming.Stream) error { + unknownMethodCalled = true + return nil + }, + } + ext := &mockExtension{} + h, err := NewSvrTransHandler(opt, ext) + test.Assert(t, err == nil, err) + svrHandler := h.(*ttheaderStreamingTransHandler) + invokeHandlerCalled := false + svrHandler.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { + invokeHandlerCalled = true + return nil + }) + + writeCnt := int32(0) + + conn := &mockConn{ + readBuf: func() []byte { + // create a client HeaderFrame for read + ivk := rpcinfo.NewInvocation(idlServiceName, "unknown-method") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + msg := remote.NewMessage(nil, svcInfo, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + msg.TransInfo().TransIntInfo()[transmeta.ToMethod] = ivk.MethodName() + msg.TransInfo().TransStrInfo()[transmeta.HeaderIDLServiceName] = ivk.ServiceName() + out := remote.NewWriterBuffer(1024) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + if err = ttheader.NewStreamCodec().Encode(ctx, msg, out); err != nil { + panic(err) + } + buf, _ := out.Bytes() + return buf + }(), + write: func(b []byte) (int, error) { + ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, nil) + msg := remote.NewMessage(nil, svcInfo, ri, remote.Stream, remote.Server) + in := remote.NewReaderBuffer(b) + if err := ttheader.NewStreamCodec().Decode(context.Background(), msg, in); err != nil { + panic(err) + } + if writeCnt == 0 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader) + } else if writeCnt == 1 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeTrailer) + } else { + test.Assert(t, false, "writeCnt should be 0 or 1") + } + writeCnt += 1 + return len(b), nil + }, + remoteAddr: addr, + } + + ctx := context.Background() + err = svrHandler.OnRead(ctx, conn) + test.Assert(t, err == nil, err) + test.Assert(t, !invokeHandlerCalled, invokeHandlerCalled) + test.Assert(t, unknownMethodCalled, unknownMethodCalled) + test.Assert(t, writeCnt == 2, writeCnt) + }) +} diff --git a/pkg/remote/trans/ttheaderstreaming/state.go b/pkg/remote/trans/ttheaderstreaming/state.go new file mode 100644 index 0000000000..a4ffc6ba2a --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/state.go @@ -0,0 +1,75 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import "sync/atomic" + +const ( + stateNew uint32 = 0 + stateSet uint32 = 1 + received = stateSet // alias + sent = stateSet // alias + closed = stateSet +) + +type state struct { + meta uint32 + header uint32 + data uint32 + trailer uint32 + closed uint32 +} + +func (s *state) setMeta() (old uint32) { + return atomic.SwapUint32(&s.meta, stateSet) +} + +func (s *state) setHeader() (old uint32) { + return atomic.SwapUint32(&s.header, stateSet) +} + +func (s *state) setData() (old uint32) { + return atomic.SwapUint32(&s.data, stateSet) +} + +func (s *state) setTrailer() (old uint32) { + return atomic.SwapUint32(&s.trailer, stateSet) +} + +func (s *state) setClosed() (old uint32) { + return atomic.SwapUint32(&s.closed, stateSet) +} + +func (s *state) hasMeta() bool { + return atomic.LoadUint32(&s.meta) == stateSet +} + +func (s *state) hasHeader() bool { + return atomic.LoadUint32(&s.header) == stateSet +} + +func (s *state) hasData() bool { + return atomic.LoadUint32(&s.data) == stateSet +} + +func (s *state) hasTrailer() bool { + return atomic.LoadUint32(&s.trailer) == stateSet +} + +func (s *state) isClosed() bool { + return atomic.LoadUint32(&s.closed) == stateSet +} diff --git a/pkg/remote/trans/ttheaderstreaming/state_test.go b/pkg/remote/trans/ttheaderstreaming/state_test.go new file mode 100644 index 0000000000..b58d3a845a --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/state_test.go @@ -0,0 +1,58 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import ( + "testing" + + "github.com/cloudwego/kitex/internal/test" +) + +func Test_state_setMeta(t *testing.T) { + s := state{} + test.Assert(t, s.setMeta() == stateNew) + test.Assert(t, s.hasMeta()) + test.Assert(t, s.setMeta() == stateSet) +} + +func Test_state_setHeader(t *testing.T) { + s := state{} + test.Assert(t, s.setHeader() == stateNew) + test.Assert(t, s.hasHeader()) + test.Assert(t, s.setHeader() == stateSet) +} + +func Test_state_setData(t *testing.T) { + s := state{} + test.Assert(t, s.setData() == stateNew) + test.Assert(t, s.hasData()) + test.Assert(t, s.setData() == stateSet) +} + +func Test_state_setTrailer(t *testing.T) { + s := state{} + test.Assert(t, s.setTrailer() == stateNew) + test.Assert(t, s.hasTrailer()) + test.Assert(t, s.setTrailer() == stateSet) +} + +func Test_state_setClosed(t *testing.T) { + s := state{} + test.Assert(t, s.setClosed() == stateNew) + test.Assert(t, s.isClosed()) + test.Assert(t, s.setClosed() == stateSet) +} diff --git a/pkg/remote/trans/ttheaderstreaming/stream.go b/pkg/remote/trans/ttheaderstreaming/stream.go new file mode 100644 index 0000000000..89e3bf6e97 --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/stream.go @@ -0,0 +1,581 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "runtime/debug" + "sync/atomic" + + "github.com/cloudwego/kitex/pkg/gofunc" + "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/remote/trans" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" + remote_transmeta "github.com/cloudwego/kitex/pkg/remote/transmeta" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/streaming" + "github.com/cloudwego/kitex/transport" +) + +const ( + defaultSendQueueSize = 64 +) + +var ( + _ streaming.Stream = (*ttheaderStream)(nil) + + ErrIllegalHeaderWrite = errors.New("ttheader streaming: the stream is done or headers was already sent") + ErrInvalidFrame = errors.New("ttheader streaming: invalid frame") + ErrServerGotMetaFrame = errors.New("ttheader streaming: server got metadata frame") + ErrSendClosed = errors.New("ttheader streaming: send closed") + ErrDirtyFrame = errors.New("ttheader streaming: dirty frame") + + idAllocator = uint64(0) + + sendQueueSize = defaultSendQueueSize +) + +// SetSendQueueSize sets the send queue size. +func SetSendQueueSize(size int) { + sendQueueSize = size +} + +type sendRequest struct { + message remote.Message + finishSignal chan error +} + +func (r sendRequest) waitFinishSignal(ctx context.Context) error { + select { + case err := <-r.finishSignal: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +type ttheaderStream struct { + id uint64 + ctx context.Context + cancel context.CancelFunc + conn net.Conn + ri rpcinfo.RPCInfo + codec remote.Codec + rpcRole remote.RPCRole + ext trans.Extension + + recvState state + lastRecvError error + + sendState state + sendQueue chan sendRequest + sendLoopCloseSignal chan struct{} + + grpcMetadata grpcMetadata + frameHandler remote.FrameMetaHandler +} + +func newTTHeaderStream( + ctx context.Context, conn net.Conn, ri rpcinfo.RPCInfo, codec remote.Codec, + role remote.RPCRole, ext trans.Extension, handler remote.FrameMetaHandler, +) *ttheaderStream { + ctx, cancel := context.WithCancel(ctx) + st := &ttheaderStream{ + id: atomic.AddUint64(&idAllocator, 1), + ctx: ctx, + cancel: cancel, + conn: conn, + ri: ri, + codec: codec, + rpcRole: role, + ext: ext, + frameHandler: handler, + + sendQueue: make(chan sendRequest, sendQueueSize), + sendLoopCloseSignal: make(chan struct{}), + } + gofunc.GoFunc(ctx, func() { + // server will replace st.ctx, so sendLoop should not read st.ctx directly, to avoid concurrent read/write + st.sendLoop(ctx) + }) + return st +} + +func (t *ttheaderStream) Context() context.Context { + return t.ctx +} + +func (t *ttheaderStream) SetHeader(md metadata.MD) error { + if t.rpcRole != remote.Server { + panic("this method should only be used in server side stream!") + } + if t.sendState.hasHeader() || t.sendState.isClosed() { + return ErrIllegalHeaderWrite + } + t.grpcMetadata.setHeader(md) + return nil +} + +func (t *ttheaderStream) SendHeader(md metadata.MD) error { + if t.rpcRole != remote.Server { + panic("this method should only be used in server side stream!") + } + t.grpcMetadata.setHeader(md) + return t.sendHeaderFrame(nil) +} + +func (t *ttheaderStream) SetTrailer(md metadata.MD) { + if t.rpcRole != remote.Server { + panic("this method should only be used in server side stream!") + } + if md.Len() == 0 || t.sendState.isClosed() { + return + } + t.grpcMetadata.setTrailer(md) +} + +func (t *ttheaderStream) Header() (metadata.MD, error) { + if t.rpcRole != remote.Client { + panic("Header() should only be used in client side stream!") + } + if old := t.recvState.setHeader(); old != received { + if err := t.readUntilTargetFrame(nil, codec.FrameTypeHeader); err != nil { + return nil, err + } + } + return t.grpcMetadata.headersReceived, nil +} + +func (t *ttheaderStream) Trailer() metadata.MD { + if t.rpcRole != remote.Client { + panic("Trailer() should only be used in client side stream!") + } + if old := t.recvState.setTrailer(); old != received { + _ = t.readUntilTargetFrame(nil, codec.FrameTypeTrailer) + } + return t.grpcMetadata.trailerReceived +} + +func (t *ttheaderStream) RecvMsg(m interface{}) error { + if t.recvState.isClosed() { + return t.lastRecvError + } + readErr := t.readUntilTargetFrame(m, codec.FrameTypeData) + if readErr == io.EOF && t.ri.Invocation().BizStatusErr() != nil { + return nil // same behavior as grpc streaming: return nil for biz status error + } + return readErr +} + +func (t *ttheaderStream) SendMsg(m interface{}) error { + if t.sendState.isClosed() { + return ErrSendClosed + } + if err := t.sendHeaderFrame(nil); err != nil { + return err + } + + message := t.newMessageToSend(m, codec.FrameTypeData) + return t.putMessageToSendQueue(message) +} + +func (t *ttheaderStream) Close() error { + return t.closeSend(nil) +} + +func (t *ttheaderStream) waitSendLoopClose() { + <-t.sendLoopCloseSignal +} + +func (t *ttheaderStream) sendHeaderFrame(msg remote.Message) error { + if old := t.sendState.setHeader(); old == sent { + return nil + } + if msg == nil { // client header is constructed in client_handler.go + msg = t.newMessageToSend(nil, codec.FrameTypeHeader) + if err := t.frameHandler.WriteHeader(t.ctx, msg); err != nil { + return fmt.Errorf("server frameMetaHandler.WriteHeader failed: %w", err) + } + } + if err := t.grpcMetadata.injectHeader(msg); err != nil { + return fmt.Errorf("grpcMetadata.injectHeader failed: err=%w", err) + } + return t.putMessageToSendQueue(msg) +} + +func (t *ttheaderStream) sendTrailerFrame(invokeErr error) error { + if old := t.sendState.setTrailer(); old == sent { + return nil + } + if err := t.sendHeaderFrame(nil); err != nil { // make sure header has been sent + return err + } + message := t.newTrailerFrame(invokeErr) + return t.putMessageToSendQueue(message) +} + +func (t *ttheaderStream) closeRecv(err error) { + if old := t.recvState.setClosed(); old == closed { + return + } + if err == nil { + err = io.EOF + } + t.lastRecvError = err +} + +func (t *ttheaderStream) newMessageToSend(m interface{}, frameType string) remote.Message { + message := remote.NewMessage(m, nil, t.ri, remote.Stream, t.rpcRole) + message.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, t.ri.Config().PayloadCodec())) + message.Tags()[codec.HeaderFlagsKey] = codec.HeaderFlagsStreaming + message.TransInfo().TransIntInfo()[remote_transmeta.FrameType] = frameType + return message +} + +func (t *ttheaderStream) readMessageWatchingCtxDone(data interface{}) (message remote.Message, err error) { + frameReceived := make(chan struct{}, 1) + + var readErr error + gofunc.GoFunc(t.ctx, func() { + message, readErr = t.readMessage(data) + frameReceived <- struct{}{} + }) + + select { + case <-frameReceived: + return message, readErr + case <-t.ctx.Done(): + doneErr := t.ctx.Err() + t.closeRecv(doneErr) + return nil, doneErr + } +} + +func (t *ttheaderStream) readMessage(data interface{}) (message remote.Message, err error) { + defer func() { + if panicInfo := recover(); panicInfo != nil { + err = fmt.Errorf("readMessage panic: %v, stack=%s", panicInfo, string(debug.Stack())) + } + if err != nil { + t.closeRecv(err) + if message != nil { + message.Recycle() + message = nil + } + } + }() + if t.recvState.isClosed() { + return nil, t.lastRecvError + } + message = remote.NewMessage(data, nil, t.ri, remote.Stream, t.rpcRole) + bufReader := t.ext.NewReadByteBuffer(t.ctx, t.conn, message) + err = t.codec.Decode(t.ctx, message, bufReader) + return message, err +} + +func (t *ttheaderStream) parseMetaFrame(message remote.Message) (err error) { + if oldState := t.recvState.setMeta(); oldState == received { + return fmt.Errorf("unexpected meta frame, recvState=%v", t.recvState) + } + if err = t.frameHandler.ReadMeta(t.ctx, message); err != nil { + return fmt.Errorf("clientMetaHandler.ReadMeta failed: %w", err) + } + return err +} + +func (t *ttheaderStream) parseHeaderFrame(message remote.Message) (err error) { + t.recvState.setHeader() + if t.rpcRole == remote.Client { + if err = t.frameHandler.ClientReadHeader(t.ctx, message); err != nil { + return fmt.Errorf("frameHandler.ClientReadHeader failed: %w", err) + } + return t.grpcMetadata.parseHeader(message) + } else /* remote.Server */ { + // server header frame is already processed in server_handler.go + return nil + } +} + +func (t *ttheaderStream) parseDataFrame(message remote.Message) error { + if !t.recvState.hasHeader() { + return fmt.Errorf("unexpected data frame before header frame, recvState=%v", t.recvState) + } + return nil +} + +func (t *ttheaderStream) parseTrailerFrame(message remote.Message) (err error) { + defer func() { + t.closeRecv(err) + }() + + t.recvState.setTrailer() + if !t.recvState.hasHeader() { + return fmt.Errorf("unexpected trailer frame, recvState=%v", t.recvState) + } + + if err = t.frameHandler.ReadTrailer(t.ctx, message); err != nil { + var transErr *remote.TransError + if errors.As(err, &transErr) { + return transErr + } + return fmt.Errorf("frameHandler.ReadTrailer failed: %w", err) + } + + if t.rpcRole == remote.Client { + if err = t.grpcMetadata.parseTrailer(message); err != nil { + return fmt.Errorf("client grpcMetadata.parseTrailer failed: %w", err) + } + } + return io.EOF +} + +func (t *ttheaderStream) readAndParseMessage(data interface{}) (string, error) { + message, err := t.readMessage(data) + defer func() { + if message != nil { + message.Recycle() + } + }() + if err != nil { + return codec.FrameTypeInvalid, err + } + frameType := codec.MessageFrameType(message) + switch frameType { + case codec.FrameTypeMeta: + if err = t.parseMetaFrame(message); err != nil { + return frameType, err + } + case codec.FrameTypeHeader: + if err = t.parseHeaderFrame(message); err != nil { + return frameType, err + } + case codec.FrameTypeData: + if err = t.parseDataFrame(message); err != nil { + return frameType, err + } + case codec.FrameTypeTrailer: + if err = t.parseTrailerFrame(message); err != nil { + return frameType, err + } + default: + return frameType, ErrInvalidFrame + } + return frameType, nil +} + +func (t *ttheaderStream) readUntilTargetFrame(data interface{}, targetFrameType string) error { + frameReceived := make(chan struct{}, 1) + + var readErr error + gofunc.GoFunc(t.ctx, func() { + readErr = t.readUntilTargetFrameSynchronously(data, targetFrameType) + frameReceived <- struct{}{} + }) + + select { + case <-frameReceived: + return readErr + case <-t.ctx.Done(): + doneErr := t.ctx.Err() + t.closeRecv(doneErr) + return doneErr + } +} + +func (t *ttheaderStream) readUntilTargetFrameSynchronously(data interface{}, targetFrameType string) (err error) { + defer func() { + if err != nil { + t.closeRecv(err) + } + }() + var frameType string + for { + if frameType, err = t.readAndParseMessage(data); err != nil { + return err // err=io.EOF for TrailerFrame + } + if frameType == targetFrameType { + return nil + } + } +} + +func (t *ttheaderStream) putMessageToSendQueue(message remote.Message) error { + sendReq := sendRequest{ + message: message, + finishSignal: make(chan error, 1), + } + select { + case t.sendQueue <- sendReq: + return t.waitSendFinishedSignal(sendReq.finishSignal) + case <-t.ctx.Done(): + t.sendState.setClosed() // state change is enough here, since sendLoop will deal with the TrailerFrame + return t.ctx.Err() + case <-t.sendLoopCloseSignal: + return ErrSendClosed + } +} + +func (t *ttheaderStream) waitSendFinishedSignal(finishSignal chan error) error { + select { + case err := <-finishSignal: + return err + case <-t.ctx.Done(): + return t.ctx.Err() + case <-t.sendLoopCloseSignal: + return ErrSendClosed + } +} + +func (t *ttheaderStream) closeSend(err error) error { + if old := t.sendState.setClosed(); old == closed { + return nil + } + + sendErr := t.sendTrailerFrame(err) + t.waitSendLoopClose() // make sure all frames are sent before the connection is reused + if t.rpcRole == remote.Server { + t.closeRecv(ErrSendClosed) + t.cancel() // release blocking RecvMsg call + } + return sendErr +} + +func (t *ttheaderStream) newTrailerFrame(err error) remote.Message { + msg := t.newMessageToSend(err, codec.FrameTypeTrailer) + if err = t.frameHandler.WriteTrailer(t.ctx, msg); err != nil { + klog.Warnf("frameHandler.WriteTrailer failed: %v", err) + } + if err = t.grpcMetadata.injectTrailer(msg); err != nil { + klog.Warnf("grpcMetadata.injectTrailer failed: err=%v", err) + } + // always send the TrailerFrame regardless of the error + return msg +} + +func (t *ttheaderStream) sendLoop(ctx context.Context) { + defer func() { + if panicInfo := recover(); panicInfo != nil { + err := fmt.Errorf("sendLoop panic: info=%v, stack=%s", panicInfo, string(debug.Stack())) + klog.CtxWarnf(t.ctx, "KITEX: ttheader streaming sendLoop panic, info=%v, stack=%s", + panicInfo, string(debug.Stack())) + t.closeRecv(err) + t.cancel() // release any possible blocking call + } + t.sendState.setClosed() + t.notifyAllSendRequests() + close(t.sendLoopCloseSignal) + }() + var message remote.Message + for { + var finishSignal chan error + select { + case req := <-t.sendQueue: + finishSignal = req.finishSignal + message = req.message + case <-ctx.Done(): + message = t.newTrailerFrame(ctx.Err()) + } + frameType := codec.MessageFrameType(message) + err := t.writeMessage(message) // message is recycled + if finishSignal != nil { + finishSignal <- err + } + if err != nil || frameType == codec.FrameTypeTrailer { + break + } + } +} + +func (t *ttheaderStream) notifyAllSendRequests() { + for { + select { + case req, ok := <-t.sendQueue: + if ok { + req.finishSignal <- ErrSendClosed + } + default: + return + } + } +} + +func (t *ttheaderStream) writeMessage(message remote.Message) error { + defer message.Recycle() // not referenced after this, including the HeaderFrame + out := t.ext.NewWriteByteBuffer(t.ctx, t.conn, message) + if err := t.codec.Encode(t.ctx, message, out); err != nil { + return err + } + return out.Flush() +} + +func (t *ttheaderStream) loadOutgoingMetadataForGRPC() { + t.grpcMetadata.loadHeadersToSend(t.ctx) +} + +func (t *ttheaderStream) clientReadMetaFrame() error { + if t.recvState.isClosed() { + return t.lastRecvError + } + return t.readUntilTargetFrame(nil, codec.FrameTypeMeta) +} + +func (t *ttheaderStream) serverReadHeaderFrame(opt *remote.ServerOption) (nCtx context.Context, err error) { + var clientHeader remote.Message + defer func() { + if clientHeader != nil { + clientHeader.Recycle() + } + }() + if clientHeader, err = t.readMessageWatchingCtxDone(nil); err != nil { + return nil, err + } + if codec.MessageFrameType(clientHeader) != codec.FrameTypeHeader { + return nil, ErrDirtyFrame + } + + t.recvState.setHeader() + + if nCtx, err = t.serverReadHeaderMeta(opt, clientHeader); err != nil { + return nil, err + } + + t.ctx = opt.TracerCtl.DoStart(nCtx, t.ri) // replace streamCtx since new values are added to the ctx + return t.ctx, err +} + +func (t *ttheaderStream) serverReadHeaderMeta(opt *remote.ServerOption, msg remote.Message) (context.Context, error) { + var err error + nCtx := t.ctx // with cancel + + if nCtx, err = t.frameHandler.ServerReadHeader(nCtx, msg); err != nil { + return nil, fmt.Errorf("frameHandler.ServerReadHeader failed: %w", err) + } + + // StreamingMetaHandler.OnReadStream + for _, handler := range opt.StreamingMetaHandlers { + if nCtx, err = handler.OnReadStream(nCtx); err != nil { + return nil, err + } + } + return nCtx, nil +} diff --git a/pkg/remote/trans/ttheaderstreaming/stream_cleaner.go b/pkg/remote/trans/ttheaderstreaming/stream_cleaner.go new file mode 100644 index 0000000000..7dfc30d821 --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/stream_cleaner.go @@ -0,0 +1,61 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import ( + "io" + "time" + + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/remote/remotecli" + "github.com/cloudwego/kitex/pkg/streaming" + "github.com/cloudwego/kitex/transport" +) + +var cleanConnectionTimeout = time.Second + +// SetCleanConnectionTimeout sets the timeout for cleaning the connection before release. +// It's only used for client +func SetCleanConnectionTimeout(t time.Duration) { + cleanConnectionTimeout = t +} + +func init() { + remotecli.RegisterCleaner(transport.TTHeader, &remotecli.StreamCleaner{ + Async: true, + Timeout: cleanConnectionTimeout, + Clean: streamCleaner, + }) +} + +// streamCleaner discards remaining frames of the stream by reading until TrailerFrame (or an error), +// and returns the last RPC error (if any). +func streamCleaner(st streaming.Stream) error { + st.Trailer() + if err := st.RecvMsg(nil); isRPCError(err) { + return err + } + return nil +} + +func isRPCError(err error) bool { + if err == nil || err == io.EOF { + return false + } + _, isBizStatusError := err.(kerrors.BizStatusErrorIface) + return !isBizStatusError // BizStatusError is not considered an RPC Error +} diff --git a/pkg/remote/trans/ttheaderstreaming/stream_cleaner_test.go b/pkg/remote/trans/ttheaderstreaming/stream_cleaner_test.go new file mode 100644 index 0000000000..f4609fb2cc --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/stream_cleaner_test.go @@ -0,0 +1,72 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import ( + "errors" + "io" + "testing" + "time" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" +) + +func TestSetCleanConnectionTimeout(t *testing.T) { + SetCleanConnectionTimeout(time.Second * 2) + test.Assert(t, cleanConnectionTimeout == time.Second*2) + + SetCleanConnectionTimeout(time.Second) + test.Assert(t, cleanConnectionTimeout == time.Second) +} + +func Test_isRPCError(t *testing.T) { + t.Run("nil", func(t *testing.T) { + test.Assert(t, !isRPCError(nil)) + }) + + t.Run("io.EOF", func(t *testing.T) { + test.Assert(t, !isRPCError(io.EOF)) + }) + + t.Run("kerrors.NewBizStatusError", func(t *testing.T) { + test.Assert(t, !isRPCError(kerrors.NewBizStatusError(100, ""))) + }) + + t.Run("other error", func(t *testing.T) { + test.Assert(t, isRPCError(errors.New("test error"))) + }) +} + +func Test_streamCleaner(t *testing.T) { + t.Run("rpc-error", func(t *testing.T) { + err := streamCleaner(&mockStream{ + doRecvMsg: func(msg interface{}) error { + return errors.New("test error") + }, + }) + test.Assert(t, err != nil) + }) + t.Run("non-rpc-error", func(t *testing.T) { + err := streamCleaner(&mockStream{ + doRecvMsg: func(msg interface{}) error { + return io.EOF + }, + }) + test.Assert(t, err == nil) + }) +} diff --git a/pkg/remote/trans/ttheaderstreaming/stream_test.go b/pkg/remote/trans/ttheaderstreaming/stream_test.go new file mode 100644 index 0000000000..df01f1d5b9 --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/stream_test.go @@ -0,0 +1,2630 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import ( + "context" + "errors" + "io" + "net" + "reflect" + "strings" + "sync" + "testing" + "time" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" + ktest "github.com/cloudwego/kitex/pkg/protocol/bthrift/test/kitex_gen/test" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/remote/codec/thrift" + "github.com/cloudwego/kitex/pkg/remote/codec/ttheader" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" + "github.com/cloudwego/kitex/pkg/remote/transmeta" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/transport" +) + +func mockRPCInfo(args ...interface{}) rpcinfo.RPCInfo { + addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8888") + from := rpcinfo.NewEndpointInfo("from", "from", nil, map[string]string{}) + to := rpcinfo.NewEndpointInfo("to", "method", addr, map[string]string{}) + var ivk rpcinfo.Invocation = rpcinfo.NewInvocation("idl-service", "method") + cfg := rpcinfo.NewRPCConfig() + cfg.(rpcinfo.MutableRPCConfig).SetPayloadCodec(serviceinfo.Protobuf) + stat := rpcinfo.NewRPCStats() + for _, obj := range args { + switch t := obj.(type) { + case rpcinfo.RPCConfig: + cfg = t + case rpcinfo.Invocation: + ivk = t + case rpcinfo.EndpointInfo: + if _, exists := t.Tag("from"); exists { + from = t + } else if _, exists := t.Tag("to"); exists { + to = t + } else { + panic("invalid endpoint info for mockRPCInfo") + } + case rpcinfo.RPCStats: + stat = t + } + } + return rpcinfo.NewRPCInfo(from, to, ivk, cfg, stat) +} + +func Test_newTTHeaderStream(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + writeCnt := 0 + conn := &mockConn{ + write: func(b []byte) (int, error) { + msg, err := decodeMessage(context.Background(), b, nil, remote.Client) + test.Assert(t, err == nil, err) + if writeCnt == 0 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, codec.MessageFrameType(msg)) + } else if writeCnt == 1 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeTrailer, codec.MessageFrameType(msg)) + } + writeCnt += 1 + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + err := st.closeSend(nil) + test.Assert(t, err == nil, err) + test.Assert(t, writeCnt == 2, writeCnt) +} + +func Test_ttheaderStream_SetHeader(t *testing.T) { + t.Run("client", func(t *testing.T) { + var st *ttheaderStream + defer func() { + panicInfo := recover() + test.Assert(t, panicInfo != nil, panicInfo) + _ = st.closeSend(nil) // avoid goroutine leak + }() + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + md := metadata.MD{ + "key": []string{"v1", "v2"}, + } + _ = st.SetHeader(md) + test.Assert(t, false, "should panic") + }) + + t.Run("server:header-sent", func(t *testing.T) { + var st *ttheaderStream + defer func() { + panicInfo := recover() + test.Assert(t, panicInfo == nil, panicInfo) + _ = st.closeSend(nil) // avoid goroutine leak + }() + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + md := metadata.MD{ + "key": []string{"v1", "v2"}, + } + + err := st.SendHeader(nil) + test.Assert(t, err == nil, err) + + err = st.SetHeader(md) + test.Assert(t, errors.Is(err, ErrIllegalHeaderWrite), err) + }) + + t.Run("server:closed", func(t *testing.T) { + var st *ttheaderStream + defer func() { + panicInfo := recover() + test.Assert(t, panicInfo == nil, panicInfo) + _ = st.closeSend(nil) // avoid goroutine leak + }() + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + md := metadata.MD{ + "key": []string{"v1", "v2"}, + } + + st.sendState.setClosed() + + err := st.SetHeader(md) + test.Assert(t, errors.Is(err, ErrIllegalHeaderWrite), err) + }) + + t.Run("normal", func(t *testing.T) { + var st *ttheaderStream + defer func() { + panicInfo := recover() + test.Assert(t, panicInfo == nil, panicInfo) + _ = st.closeSend(nil) // avoid goroutine leak + }() + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + md := metadata.MD{ + "key": []string{"v1", "v2"}, + } + + err := st.SetHeader(md) + test.Assert(t, err == nil, err) + + test.Assert(t, reflect.DeepEqual(st.grpcMetadata.headersToSend, md), st.grpcMetadata.headersToSend) + }) +} + +func Test_ttheaderStream_SendHeader(t *testing.T) { + t.Run("client", func(t *testing.T) { + var st *ttheaderStream + defer func() { + panicInfo := recover() + test.Assert(t, panicInfo != nil, panicInfo) + _ = st.closeSend(nil) // avoid goroutine leak + }() + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + writeCount := 0 + conn := &mockConn{ + write: func(b []byte) (int, error) { + msg, err := decodeMessage(context.Background(), b, nil, remote.Server) + test.Assert(t, err == nil, err) + if writeCount == 0 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, codec.MessageFrameType(msg)) + } + writeCount += 1 + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + md := metadata.MD{ + "key": []string{"v1", "v2"}, + } + _ = st.SendHeader(md) + test.Assert(t, false, "should panic") + }) + + t.Run("server:normal", func(t *testing.T) { + var st *ttheaderStream + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + writeCnt := 0 + conn := &mockConn{ + write: func(b []byte) (int, error) { + msg, err := decodeMessage(context.Background(), b, nil, remote.Server) + test.Assert(t, err == nil, err) + if writeCnt == 0 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, codec.MessageFrameType(msg)) + } else if writeCnt == 1 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeTrailer, codec.MessageFrameType(msg)) + } + writeCnt += 1 + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + md := metadata.MD{ + "key": []string{"v1", "v2"}, + } + err := st.SendHeader(md) + test.Assert(t, err == nil, err) + test.Assert(t, reflect.DeepEqual(st.grpcMetadata.headersToSend, md), st.grpcMetadata.headersToSend) + + _ = st.closeSend(nil) // avoid goroutine leak + test.Assert(t, writeCnt == 2, writeCnt) + }) +} + +func Test_ttheaderStream_SetTrailer(t *testing.T) { + t.Run("client", func(t *testing.T) { + var st *ttheaderStream + defer func() { + panicInfo := recover() + test.Assert(t, panicInfo != nil, panicInfo) + _ = st.closeSend(nil) // avoid goroutine leak + }() + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + md := metadata.MD{ + "key": []string{"v1", "v2"}, + } + st.SetTrailer(md) + test.Assert(t, false, "should panic") + }) + + t.Run("server:normal", func(t *testing.T) { + var st *ttheaderStream + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + md := metadata.MD{ + "key": []string{"v1", "v2"}, + } + st.SetTrailer(md) + test.Assert(t, reflect.DeepEqual(st.grpcMetadata.trailerToSend, md), st.grpcMetadata.headersToSend) + + _ = st.closeSend(nil) // avoid goroutine leak + }) + + t.Run("server:empty-md", func(t *testing.T) { + var st *ttheaderStream + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + md := metadata.MD{} + st.SetTrailer(md) + test.Assert(t, len(st.grpcMetadata.trailerToSend) == 0, st.grpcMetadata.trailerToSend) + + _ = st.closeSend(nil) // avoid goroutine leak + }) + + t.Run("server:closed", func(t *testing.T) { + var st *ttheaderStream + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + _ = st.closeSend(nil) // avoid goroutine leak + + md := metadata.MD{"key": []string{"k1", "k2"}} + st.SetTrailer(md) + test.Assert(t, len(st.grpcMetadata.trailerToSend) == 0, st.grpcMetadata.trailerToSend) + }) +} + +func Test_ttheaderStream_Header(t *testing.T) { + t.Run("server:panic", func(t *testing.T) { + var st *ttheaderStream + defer func() { + panicInfo := recover() + test.Assert(t, panicInfo != nil, panicInfo) + _ = st.closeSend(nil) // avoid goroutine leak + }() + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + _, _ = st.Header() + test.Assert(t, false, "should panic") + }) + + t.Run("client:received", func(t *testing.T) { + var st *ttheaderStream + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + st.grpcMetadata.headersReceived = metadata.MD{ + "key": []string{"v1", "v2"}, + } + + st.recvState.setHeader() + + md, err := st.Header() + test.Assert(t, err == nil, err) + test.Assert(t, len(md["key"]) == 2, md) + + _ = st.closeSend(nil) // avoid goroutine leak + }) + + t.Run("client:not-received,read-header", func(t *testing.T) { + var st *ttheaderStream + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + readBuf: func() []byte { + ri := mockRPCInfo() + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + msg.TransInfo().TransStrInfo()[codec.StrKeyMetaData] = `{"header":["v1","v2"]}` + buf, err := encodeMessage(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + return buf + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + md, err := st.Header() + test.Assert(t, err == nil, err) + test.Assert(t, len(md["header"]) == 2, md) + + _ = st.closeSend(nil) // avoid goroutine leak + }) + + t.Run("client:not-received,read-header-err", func(t *testing.T) { + var st *ttheaderStream + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + readErr := errors.New("read header error") + conn := &mockConn{ + read: func(b []byte) (int, error) { + return 0, readErr + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + _, err := st.Header() + test.Assert(t, errors.Is(err, readErr), err) + + _ = st.closeSend(nil) // avoid goroutine leak + }) +} + +func Test_ttheaderStream_Trailer(t *testing.T) { + t.Run("server:panic", func(t *testing.T) { + var st *ttheaderStream + defer func() { + panicInfo := recover() + test.Assert(t, panicInfo != nil, panicInfo) + _ = st.closeSend(nil) // avoid goroutine leak + }() + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + _ = st.Trailer() + test.Assert(t, false, "should panic") + }) + + t.Run("client:trailer-received", func(t *testing.T) { + var st *ttheaderStream + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + st.recvState.setTrailer() + st.grpcMetadata.trailerReceived = metadata.MD{ + "trailer": []string{"v1", "v2"}, + } + + md := st.Trailer() + test.Assert(t, len(md["trailer"]) == 2, md) + + _ = st.closeSend(nil) // avoid goroutine leak + }) + + t.Run("client:not-received,read-trailer", func(t *testing.T) { + var st *ttheaderStream + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + readBuf: func() []byte { + ri := mockRPCInfo() + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + bufHeader, err := encodeMessage(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + + msg = remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeTrailer + msg.TransInfo().TransStrInfo()[codec.StrKeyMetaData] = `{"trailer":["v1","v2"]}` + bufTrailer, err := encodeMessage(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + return append(bufHeader, bufTrailer...) + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st = newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + test.Assert(t, rpcinfo.GetRPCInfo(st.Context()) != nil) + + md := st.Trailer() + test.Assert(t, len(md["trailer"]) == 2, md) + + _ = st.closeSend(nil) // avoid goroutine leak + }) +} + +func Test_ttheaderStream_RecvMsg(t *testing.T) { + t.Run("normal", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + req := &ktest.Local{L: 1} + remote.PutPayloadCode(serviceinfo.Thrift, thrift.NewThriftCodec()) + conn := &mockConn{ + readBuf: func() []byte { + cfg := rpcinfo.NewRPCConfig() + cfg.(rpcinfo.MutableRPCConfig).SetPayloadCodec(serviceinfo.Thrift) + ri := mockRPCInfo(cfg) + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + bufHeader, err := encodeMessage(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + + msg = remote.NewMessage(req, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeData + bufData, err := encodeMessage(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + return append(bufHeader, bufData...) + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + + gotReq := &ktest.Local{} + err := st.RecvMsg(gotReq) + test.Assert(t, err == nil, err) + test.Assert(t, reflect.DeepEqual(req, gotReq), req, gotReq) + + _ = st.closeSend(nil) // avoid goroutine leak + }) + + t.Run("closed", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + + st.closeRecv(io.EOF) + + err := st.RecvMsg(nil) + test.Assert(t, err == io.EOF, err) + + _ = st.closeSend(nil) // avoid goroutine leak + }) + + t.Run("io.EOF,BizStatusError", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + remote.PutPayloadCode(serviceinfo.Thrift, thrift.NewThriftCodec()) + conn := &mockConn{ + readBuf: func() []byte { + ri := mockRPCInfo() + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + bufHeader, err := encodeMessage(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + + msg = remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeTrailer + bufTrailer, err := encodeMessage(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + return append(bufHeader, bufTrailer...) + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + + ri.Invocation().(rpcinfo.InvocationSetter).SetBizStatusErr(kerrors.NewBizStatusError(1, "biz err")) + + err := st.RecvMsg(nil) + test.Assert(t, err == nil, err) + + _ = st.closeSend(nil) // avoid goroutine leak + }) + + t.Run("read-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + remote.PutPayloadCode(serviceinfo.Thrift, thrift.NewThriftCodec()) + conn := &mockConn{ + readBuf: []byte{}, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + + err := st.RecvMsg(nil) + test.Assert(t, err != nil, err) + + _ = st.closeSend(nil) // avoid goroutine leak + }) +} + +func Test_ttheaderStream_SendMsg(t *testing.T) { + t.Run("normal", func(t *testing.T) { + cfg := rpcinfo.NewRPCConfig() + cfg.(rpcinfo.MutableRPCConfig).SetPayloadCodec(serviceinfo.Thrift) + ri := mockRPCInfo(cfg) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + remote.PutPayloadCode(serviceinfo.Thrift, thrift.NewThriftCodec()) + var writeBuf []byte + conn := &mockConn{ + write: func(b []byte) (int, error) { + writeBuf = append(writeBuf, b...) + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + + req := &ktest.Local{L: 1} + err := st.SendMsg(req) + test.Assert(t, err == nil, err) + + _ = st.closeSend(nil) // avoid goroutine leak + + in := remote.NewReaderBuffer(writeBuf) + msg, err := decodeMessageFromBuffer(context.Background(), in, nil, remote.Client) + test.Assert(t, err == nil, err) + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, codec.MessageFrameType(msg)) + + data := &ktest.Local{} + msg, err = decodeMessageFromBuffer(context.Background(), in, data, remote.Client) + test.Assert(t, err == nil, err) + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeData, codec.MessageFrameType(msg)) + test.Assert(t, reflect.DeepEqual(req, data), req, data) + + msg, err = decodeMessageFromBuffer(context.Background(), in, nil, remote.Client) + test.Assert(t, err == nil, err) + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeTrailer, codec.MessageFrameType(msg)) + }) + + t.Run("send-closed", func(t *testing.T) { + cfg := rpcinfo.NewRPCConfig() + cfg.(rpcinfo.MutableRPCConfig).SetPayloadCodec(serviceinfo.Thrift) + ri := mockRPCInfo(cfg) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + + err := st.closeSend(nil) + test.Assert(t, err == nil, err) + + err = st.SendMsg(nil) + test.Assert(t, errors.Is(err, ErrSendClosed), err) + }) + + t.Run("send-header-err", func(t *testing.T) { + cfg := rpcinfo.NewRPCConfig() + cfg.(rpcinfo.MutableRPCConfig).SetPayloadCodec(serviceinfo.Thrift) + ri := mockRPCInfo(cfg) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + write: func(b []byte) (int, error) { + return -1, io.ErrClosedPipe + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + + err := st.SendMsg(nil) + test.Assert(t, err != nil, err) + + _ = st.closeSend(nil) + }) +} + +func Test_ttheaderStream_Close(t *testing.T) { + t.Run("normal", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + var writeBuf []byte + conn := &mockConn{ + write: func(b []byte) (int, error) { + writeBuf = append(writeBuf, b...) + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + + _ = st.Close() // avoid goroutine leak + + in := remote.NewReaderBuffer(writeBuf) + msg, err := decodeMessageFromBuffer(context.Background(), in, nil, remote.Client) + test.Assert(t, err == nil, err) + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, codec.MessageFrameType(msg)) + + msg, err = decodeMessageFromBuffer(context.Background(), in, nil, remote.Client) + test.Assert(t, err == nil, err) + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeTrailer, codec.MessageFrameType(msg)) + }) +} + +func Test_ttheaderStream_waitSendLoopClose(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + var writeBuf []byte + conn := &mockConn{ + write: func(b []byte) (int, error) { + writeBuf = append(writeBuf, b...) + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() + + finished := make(chan struct{}) + go func() { + st.waitSendLoopClose() + finished <- struct{}{} + }() + + close(st.sendLoopCloseSignal) + select { + case <-finished: + case <-time.NewTimer(time.Millisecond * 100).C: + t.Error("timeout") + } +} + +func Test_ttheaderStream_sendHeaderFrame(t *testing.T) { + t.Run("already-sent", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + write: func(b []byte) (int, error) { + msg, err := decodeMessage(context.Background(), b, nil, remote.Server) + test.Assert(t, err == nil, err) + // only TrailerFrame + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeTrailer, codec.MessageFrameType(msg)) + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + st.sendState.setHeader() + + err := st.sendHeaderFrame(nil) + test.Assert(t, err == nil, err) + + _ = st.Close() // avoid goroutine leak + }) + + t.Run("client-header", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + writeCnt := 0 + conn := &mockConn{ + write: func(b []byte) (int, error) { + msg, err := decodeMessage(context.Background(), b, nil, remote.Server) + test.Assert(t, err == nil, err) + if writeCnt == 0 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, codec.MessageFrameType(msg)) + test.Assert(t, msg.TransInfo().TransStrInfo()["key"] == "value", msg.TransInfo().TransStrInfo()) + } + writeCnt += 1 + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + msg.TransInfo().TransStrInfo()["key"] = "value" + + err := st.sendHeaderFrame(msg) + test.Assert(t, err == nil, err) + + _ = st.Close() // avoid goroutine leak + }) + + t.Run("server-header", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + writeCnt := 0 + conn := &mockConn{ + write: func(b []byte) (int, error) { + msg, err := decodeMessage(context.Background(), b, nil, remote.Server) + test.Assert(t, err == nil, err) + if writeCnt == 0 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, codec.MessageFrameType(msg)) + test.Assert(t, msg.TransInfo().TransStrInfo()[codec.StrKeyMetaData] != "", msg.TransInfo().TransStrInfo()) + } + writeCnt += 1 + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + ctx = metadata.AppendToOutgoingContext(ctx, "key", "value") + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + st.loadOutgoingMetadataForGRPC() + + err := st.sendHeaderFrame(nil) + test.Assert(t, err == nil, err) + + _ = st.Close() // avoid goroutine leak + }) +} + +func Test_ttheaderStream_sendTrailerFrame(t *testing.T) { + t.Run("normal", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + writeCnt := 0 + conn := &mockConn{ + write: func(b []byte) (int, error) { + msg, err := decodeMessage(context.Background(), b, nil, remote.Server) + test.Assert(t, err == nil, err) + switch writeCnt { + case 0: + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, codec.MessageFrameType(msg)) + case 1: + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeTrailer, codec.MessageFrameType(msg)) + default: + test.Assert(t, false, "should not be here") + } + writeCnt += 1 + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + + err := st.sendTrailerFrame(nil) + test.Assert(t, err == nil, err) + + err = st.sendTrailerFrame(nil) + test.Assert(t, err == nil, err) + + _ = st.Close() // avoid goroutine leak + }) + + t.Run("send-header-fail", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + write: func(b []byte) (int, error) { + return 0, errors.New("write error") + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + + err := st.sendTrailerFrame(nil) + test.Assert(t, err != nil, err) + + _ = st.Close() // avoid goroutine leak + }) +} + +func Test_ttheaderStream_closeRecv(t *testing.T) { + t.Run("normal", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + write: func(b []byte) (int, error) { + return 0, errors.New("write error") + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + + st.closeRecv(nil) + test.Assert(t, st.lastRecvError == io.EOF, st.lastRecvError) + + st.closeRecv(io.ErrClosedPipe) // close before, no use + test.Assert(t, st.lastRecvError == io.EOF, st.lastRecvError) + + _ = st.Close() // avoid goroutine leak + }) + + t.Run("closed-with-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + write: func(b []byte) (int, error) { + return 0, errors.New("write error") + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + + st.closeRecv(io.ErrClosedPipe) + + test.Assert(t, st.lastRecvError == io.ErrClosedPipe, st.lastRecvError) + + _ = st.Close() // avoid goroutine leak + }) +} + +func Test_ttheaderStream_newMessageToSend(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + write: func(b []byte) (int, error) { + return 0, errors.New("write error") + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + + data := ktest.Local{L: 1} + msg := st.newMessageToSend(data, codec.FrameTypeData) + + test.Assert(t, msg.RPCRole() == remote.Client, msg.RPCRole()) + test.Assert(t, msg.MessageType() == remote.Stream, msg.MessageType()) + test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Protobuf, msg.ProtocolInfo().CodecType) + test.Assert(t, msg.ProtocolInfo().TransProto == transport.TTHeader, msg.ProtocolInfo().TransProto) + test.Assert(t, msg.Tags()[codec.HeaderFlagsKey] == codec.HeaderFlagsStreaming, msg.Tags()) + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeData, codec.MessageFrameType(msg)) + + _ = st.Close() // avoid goroutine leak +} + +func Test_ttheaderStream_readMessageWatchingCtxDone(t *testing.T) { + t.Run("normal", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + readBuf: func() []byte { + ri := mockRPCInfo() + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + bufHeader, err := encodeMessage(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + return bufHeader + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + msg, err := st.readMessageWatchingCtxDone(nil) + test.Assert(t, err == nil, err) + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, codec.MessageFrameType(msg)) + }) + + t.Run("read-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + _, err := st.readMessageWatchingCtxDone(nil) + test.Assert(t, err != nil, err) + }) + + t.Run("timeout", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + read: func(b []byte) (int, error) { + time.Sleep(time.Millisecond * 20) + return 0, io.EOF + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + ctx, cancel := context.WithTimeout(ctx, time.Millisecond*10) + defer cancel() + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + _, err := st.readMessageWatchingCtxDone(nil) + test.Assert(t, err != nil, err) + test.Assert(t, errors.Is(st.lastRecvError, context.DeadlineExceeded), st.lastRecvError) + }) +} + +func Test_ttheaderStream_readMessage(t *testing.T) { + t.Run("recv-closed", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + _ = st.closeSend(nil) + + _, err := st.readMessage(nil) + test.Assert(t, errors.Is(err, ErrSendClosed), err) + }) + + t.Run("read-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + msg, err := st.readMessage(nil) + test.Assert(t, err != nil, err) + test.Assert(t, msg == nil, msg) + }) + + t.Run("read-success", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + readBuf: func() []byte { + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + bufHeader, err := encodeMessage(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + return bufHeader + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + msg, err := st.readMessage(nil) + test.Assert(t, err == nil, err) + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, codec.MessageFrameType(msg)) + }) + + t.Run("read-panic", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + read: func(b []byte) (int, error) { + panic("test panic") + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + _, err := st.readMessage(nil) + test.Assert(t, err != nil, err) + test.Assert(t, strings.Contains(err.Error(), "test panic"), err.Error()) + }) +} + +func Test_ttheaderStream_parseMetaFrame(t *testing.T) { + t.Run("already-received", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + read: func(b []byte) (int, error) { + panic("test panic") + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + st.recvState.setMeta() + + err := st.parseMetaFrame(nil) + test.Assert(t, strings.Contains(err.Error(), "unexpected meta frame"), err.Error()) + }) + + t.Run("server-got-meta-frame:should-panic", func(t *testing.T) { + defer func() { + panicInfo := recover() + test.Assert(t, panicInfo != nil, panicInfo) + }() + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + read: func(b []byte) (int, error) { + panic("test panic") + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + _ = st.parseMetaFrame(msg) + t.Error("should panic") + }) + + t.Run("normal", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + read: func(b []byte) (int, error) { + panic("test panic") + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + called := false + handler := &mockFrameMetaHandler{ + readMeta: func(ctx context.Context, msg remote.Message) error { + called = true + return nil + }, + } + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + err := st.parseMetaFrame(msg) + test.Assert(t, err == nil, err) + test.Assert(t, called) + }) +} + +func Test_ttheaderStream_parseHeaderFrame(t *testing.T) { + t.Run("client:success", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + called := false + handler := &mockFrameMetaHandler{ + clientReadHeader: func(ctx context.Context, msg remote.Message) error { + called = true + return nil + }, + } + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + msg.TransInfo().TransStrInfo()[codec.StrKeyMetaData] = `{"key":["v1","v2"]}` + err := st.parseHeaderFrame(msg) + test.Assert(t, err == nil, err) + test.Assert(t, called) + test.Assert(t, len(st.grpcMetadata.headersReceived["key"]) == 2, st.grpcMetadata.headersReceived) + test.Assert(t, st.recvState.hasHeader()) + }) + + t.Run("client:read-header-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + readErr := errors.New("read header error") + handler := &mockFrameMetaHandler{ + clientReadHeader: func(ctx context.Context, msg remote.Message) error { + return readErr + }, + } + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + msg.TransInfo().TransStrInfo()[codec.StrKeyMetaData] = `{"key":["v1","v2"]}` + err := st.parseHeaderFrame(msg) + test.Assert(t, errors.Is(err, readErr), err) + test.Assert(t, len(st.grpcMetadata.headersReceived) == 0, st.grpcMetadata.headersReceived) + test.Assert(t, st.recvState.hasHeader()) + }) + + t.Run("server", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + called := false + handler := &mockFrameMetaHandler{ + readMeta: func(ctx context.Context, msg remote.Message) error { + called = true + return nil + }, + } + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + err := st.parseHeaderFrame(msg) + test.Assert(t, err == nil, err) + test.Assert(t, !called, called) // it's called in server_handler.go + test.Assert(t, st.recvState.hasHeader()) + }) +} + +func Test_ttheaderStream_parseDataFrame(t *testing.T) { + t.Run("no-header", func(t *testing.T) { + st := &ttheaderStream{} + err := st.parseDataFrame(nil) + test.Assert(t, err != nil, err) + }) + t.Run("has-header", func(t *testing.T) { + st := &ttheaderStream{} + st.recvState.setHeader() + err := st.parseDataFrame(nil) + test.Assert(t, err == nil, err) + }) +} + +func Test_ttheaderStream_parseTrailerFrame(t *testing.T) { + t.Run("client:no-header", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := &mockFrameMetaHandler{} + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeTrailer + err := st.parseTrailerFrame(msg) + test.Assert(t, err != nil, err) + test.Assert(t, st.recvState.hasTrailer()) + }) + + t.Run("client:ReadTrailer-TransErr", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + transErr := remote.NewTransErrorWithMsg(1, "mock err") + handler := &mockFrameMetaHandler{ + readTrailer: func(ctx context.Context, msg remote.Message) error { + return transErr + }, + } + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + st.recvState.setHeader() + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeTrailer + err := st.parseTrailerFrame(msg) + test.Assert(t, errors.Is(err, transErr), err) + test.Assert(t, st.recvState.hasTrailer()) + }) + + t.Run("client:ReadTrailer-other-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + readErr := errors.New("read error") + handler := &mockFrameMetaHandler{ + readTrailer: func(ctx context.Context, msg remote.Message) error { + return readErr + }, + } + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + st.recvState.setHeader() + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeTrailer + err := st.parseTrailerFrame(msg) + var transError *remote.TransError + isTransErr := errors.As(err, &transError) + test.Assert(t, !isTransErr && transError == nil, err) + test.Assert(t, strings.Contains(err.Error(), readErr.Error()), err) + test.Assert(t, st.recvState.hasTrailer()) + }) + + t.Run("client:parse-grpc-trailer-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := &mockFrameMetaHandler{ + readTrailer: func(ctx context.Context, msg remote.Message) error { + return nil + }, + } + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + st.recvState.setHeader() + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeTrailer + msg.TransInfo().TransStrInfo()[codec.StrKeyMetaData] = `invalid trailer` + err := st.parseTrailerFrame(msg) + test.Assert(t, err != nil, err) + test.Assert(t, len(st.grpcMetadata.trailerReceived) == 0, st.grpcMetadata.trailerReceived) + test.Assert(t, errors.Is(err, st.lastRecvError), st.lastRecvError) + }) + + t.Run("client:success", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := &mockFrameMetaHandler{ + readTrailer: func(ctx context.Context, msg remote.Message) error { + return nil + }, + } + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + st.recvState.setHeader() + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeTrailer + msg.TransInfo().TransStrInfo()[codec.StrKeyMetaData] = `{"key":["v1","v2"]}` + err := st.parseTrailerFrame(msg) + test.Assert(t, err == io.EOF, err) + test.Assert(t, len(st.grpcMetadata.trailerReceived) == 1, st.grpcMetadata.trailerReceived) + test.Assert(t, st.lastRecvError == io.EOF, st.lastRecvError) + }) + + t.Run("server:success", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := &mockFrameMetaHandler{ + readTrailer: func(ctx context.Context, msg remote.Message) error { + return nil + }, + } + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + st.recvState.setHeader() + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeTrailer + msg.TransInfo().TransStrInfo()[codec.StrKeyMetaData] = `{"key":["v1","v2"]}` + err := st.parseTrailerFrame(msg) + test.Assert(t, err == io.EOF, err) + // server: not grpc-trailer from client + test.Assert(t, len(st.grpcMetadata.trailerReceived) == 0, st.grpcMetadata.trailerReceived) + test.Assert(t, st.lastRecvError == io.EOF, st.lastRecvError) + }) +} + +func Test_ttheaderStream_readAndParseMessage(t *testing.T) { + t.Run("read-msg-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := &mockFrameMetaHandler{} + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + ft, err := st.readAndParseMessage(nil) + test.Assert(t, err != nil, err) + test.Assert(t, ft == codec.FrameTypeInvalid, ft) + }) + + t.Run("parse-meta-frame-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + readBuf: func() []byte { + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeMeta + buf, _ := encodeMessage(ctx, msg) + return buf + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + mockErr := errors.New("mock error") + handler := &mockFrameMetaHandler{ + readMeta: func(ctx context.Context, msg remote.Message) error { + return mockErr + }, + } + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + ft, err := st.readAndParseMessage(nil) + test.Assert(t, errors.Is(err, mockErr), err) + test.Assert(t, ft == codec.FrameTypeMeta, ft) + }) + + t.Run("parse-meta-frame-success", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + readBuf: func() []byte { + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeMeta + buf, _ := encodeMessage(ctx, msg) + return buf + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := &mockFrameMetaHandler{ + readMeta: func(ctx context.Context, msg remote.Message) error { + return nil + }, + } + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + ft, err := st.readAndParseMessage(nil) + test.Assert(t, err == nil, err) + test.Assert(t, ft == codec.FrameTypeMeta, ft) + }) + + t.Run("parse-header-frame-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + readBuf: func() []byte { + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + buf, _ := encodeMessage(ctx, msg) + return buf + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + mockErr := errors.New("mock error") + handler := &mockFrameMetaHandler{ + clientReadHeader: func(ctx context.Context, msg remote.Message) error { + return mockErr + }, + } + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + ft, err := st.readAndParseMessage(nil) + test.Assert(t, errors.Is(err, mockErr), err) + test.Assert(t, ft == codec.FrameTypeHeader, ft) + }) + + t.Run("parse-header-frame-success", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + readBuf: func() []byte { + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + buf, _ := encodeMessage(ctx, msg) + return buf + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := &mockFrameMetaHandler{ + clientReadHeader: func(ctx context.Context, msg remote.Message) error { + return nil + }, + } + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + ft, err := st.readAndParseMessage(nil) + test.Assert(t, err == nil, err) + test.Assert(t, ft == codec.FrameTypeHeader, ft) + }) + + t.Run("parse-data-frame-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + remote.PutPayloadCode(serviceinfo.Thrift, thrift.NewThriftCodec()) + conn := &mockConn{ + readBuf: func() []byte { + data := &ktest.Local{L: 1} + msg := remote.NewMessage(data, nil, ri, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeData + buf, err := encodeMessage(ctx, msg) + test.Assert(t, err == nil, err) + return buf + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := &mockFrameMetaHandler{} + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + ft, err := st.readAndParseMessage(nil) + test.Assert(t, err != nil, err) // read data frame before header frame + test.Assert(t, ft == codec.FrameTypeData, ft) + }) + + t.Run("parse-data-frame-success", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + remote.PutPayloadCode(serviceinfo.Thrift, thrift.NewThriftCodec()) + conn := &mockConn{ + readBuf: func() []byte { + data := &ktest.Local{L: 1} + msg := remote.NewMessage(data, nil, ri, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeData + buf, err := encodeMessage(ctx, msg) + test.Assert(t, err == nil, err) + return buf + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := &mockFrameMetaHandler{ + clientReadHeader: func(ctx context.Context, msg remote.Message) error { + return nil + }, + } + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + st.recvState.setHeader() + + ft, err := st.readAndParseMessage(nil) + test.Assert(t, err == nil, err) + test.Assert(t, ft == codec.FrameTypeData, ft) + }) + + t.Run("parse-trailer-frame-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + readBuf: func() []byte { + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeTrailer + buf, err := encodeMessage(ctx, msg) + test.Assert(t, err == nil, err) + return buf + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + mockErr := errors.New("mock error") + handler := &mockFrameMetaHandler{ + readTrailer: func(ctx context.Context, msg remote.Message) error { + return mockErr + }, + } + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + st.recvState.setHeader() + + ft, err := st.readAndParseMessage(nil) + test.Assert(t, errors.Is(err, mockErr), err) + test.Assert(t, ft == codec.FrameTypeTrailer, ft) + }) + + t.Run("parse-trailer-frame-success", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + remote.PutPayloadCode(serviceinfo.Thrift, thrift.NewThriftCodec()) + conn := &mockConn{ + readBuf: func() []byte { + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeTrailer + buf, err := encodeMessage(ctx, msg) + test.Assert(t, err == nil, err) + return buf + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := &mockFrameMetaHandler{ + clientReadHeader: func(ctx context.Context, msg remote.Message) error { + return nil + }, + } + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + st.recvState.setHeader() + + ft, err := st.readAndParseMessage(nil) + test.Assert(t, err == io.EOF, err) + test.Assert(t, ft == codec.FrameTypeTrailer, ft) + }) + + t.Run("unknown-frame-type", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + remote.PutPayloadCode(serviceinfo.Thrift, thrift.NewThriftCodec()) + conn := &mockConn{ + readBuf: func() []byte { + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = "XXX" + buf, err := encodeMessage(ctx, msg) + test.Assert(t, err == nil, err) + return buf + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := &mockFrameMetaHandler{ + clientReadHeader: func(ctx context.Context, msg remote.Message) error { + return nil + }, + } + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + st.recvState.setHeader() + + _, err := st.readAndParseMessage(nil) + test.Assert(t, errors.Is(err, ErrInvalidFrame), err) + }) +} + +func Test_ttheaderStream_readUntilTargetFrame(t *testing.T) { + t.Run("normal", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + readBuf: func() []byte { + ri := mockRPCInfo() + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + bufHeader, err := encodeMessage(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + return bufHeader + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + err := st.readUntilTargetFrame(nil, codec.FrameTypeHeader) + test.Assert(t, err == nil, err) + }) + + t.Run("read-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + err := st.readUntilTargetFrame(nil, codec.FrameTypeHeader) + test.Assert(t, err != nil, err) + }) + + t.Run("timeout", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + read: func(b []byte) (int, error) { + time.Sleep(time.Millisecond * 20) + return 0, io.EOF + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + ctx, cancel := context.WithTimeout(ctx, time.Millisecond*10) + defer cancel() + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + err := st.readUntilTargetFrame(nil, codec.FrameTypeHeader) + test.Assert(t, err != nil, err) + test.Assert(t, errors.Is(st.lastRecvError, context.DeadlineExceeded), st.lastRecvError) + }) +} + +func Test_ttheaderStream_readUntilTargetFrameSynchronously(t *testing.T) { + t.Run("read-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + err := st.readUntilTargetFrameSynchronously(nil, codec.FrameTypeHeader) + test.Assert(t, err != nil, err) + }) + + t.Run("read-target-frame", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + readBuf: func() []byte { + ri := mockRPCInfo() + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + bufHeader, err := encodeMessage(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + return bufHeader + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + err := st.readUntilTargetFrameSynchronously(nil, codec.FrameTypeHeader) + test.Assert(t, err == nil, err) + }) + + t.Run("2nd-to-be-target-frame", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{ + readBuf: func() []byte { + ri := mockRPCInfo() + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + bufHeader, err := encodeMessage(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + msg = remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeTrailer + bufTrailer, err := encodeMessage(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + return append(bufHeader, bufTrailer...) + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + err := st.readUntilTargetFrameSynchronously(nil, codec.FrameTypeTrailer) + test.Assert(t, err == io.EOF, err) + }) +} + +func Test_ttheaderStream_putMessageToSendQueue(t *testing.T) { + t.Run("normal", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + writeCnt := 0 + conn := &mockConn{ + write: func(b []byte) (int, error) { + msg, err := decodeMessage(context.Background(), b, nil, remote.Client) + test.Assert(t, err == nil, err) + if writeCnt == 0 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, codec.MessageFrameType(msg)) + } + writeCnt += 1 + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + err := st.putMessageToSendQueue(msg) + test.Assert(t, err == nil, err) + test.Assert(t, writeCnt > 0) + }) + + t.Run("ctx-done", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ctx, cancel := context.WithTimeout(ctx, time.Millisecond*10) + defer cancel() + writeCnt := 0 + conn := &mockConn{ + write: func(b []byte) (int, error) { + msg, err := decodeMessage(context.Background(), b, nil, remote.Client) + test.Assert(t, err == nil, err) + if writeCnt == 0 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, codec.MessageFrameType(msg)) + time.Sleep(time.Millisecond * 20) + } + writeCnt += 1 + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + err := st.putMessageToSendQueue(msg) + test.Assert(t, errors.Is(err, context.DeadlineExceeded), err) + }) + + t.Run("sendLoopCloseSignal-closed", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + writeCnt := 0 + conn := &mockConn{ + write: func(b []byte) (int, error) { + msg, err := decodeMessage(context.Background(), b, nil, remote.Client) + test.Assert(t, err == nil, err) + if writeCnt == 0 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, codec.MessageFrameType(msg)) + time.Sleep(time.Millisecond * 50) + } + writeCnt += 1 + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + _ = st.Close() // finish sendLoop and close sendLoopCloseSignal + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + err := st.putMessageToSendQueue(msg) + test.Assert(t, errors.Is(err, ErrSendClosed), err) + }) +} + +func Test_sendRequest_waitSendFinishedSignal(t *testing.T) { + t.Run("normal", func(t *testing.T) { + st := &ttheaderStream{ + ctx: context.Background(), + sendLoopCloseSignal: make(chan struct{}), + } + ch := make(chan error, 1) + ch <- nil + err := st.waitSendFinishedSignal(ch) + test.Assert(t, err == nil, err) + }) + t.Run("ctx-done", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + st := &ttheaderStream{ + ctx: ctx, + sendLoopCloseSignal: make(chan struct{}), + } + cancel() + ch := make(chan error, 1) + err := st.waitSendFinishedSignal(ch) + test.Assert(t, errors.Is(err, context.Canceled), err) + }) + t.Run("sendLoopCloseSignal", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + st := &ttheaderStream{ + ctx: ctx, + sendLoopCloseSignal: make(chan struct{}), + } + defer cancel() + close(st.sendLoopCloseSignal) + ch := make(chan error, 1) + err := st.waitSendFinishedSignal(ch) + test.Assert(t, errors.Is(err, ErrSendClosed), err) + }) +} + +func Test_sendRequest_waitFinishSignal(t *testing.T) { + t.Run("normal", func(t *testing.T) { + sendReq := sendRequest{ + finishSignal: make(chan error, 1), + } + sendReq.finishSignal <- nil + err := sendReq.waitFinishSignal(context.Background()) + test.Assert(t, err == nil, err) + }) + t.Run("canceled", func(t *testing.T) { + sendReq := sendRequest{ + finishSignal: make(chan error, 1), + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := sendReq.waitFinishSignal(ctx) + test.Assert(t, errors.Is(err, context.Canceled), err) + }) +} + +func Test_ttheaderStream_closeSend(t *testing.T) { + t.Run("closed-before", func(t *testing.T) { + st := &ttheaderStream{} + st.sendState.setClosed() + err := st.closeSend(errors.New("XXX")) + test.Assert(t, err == nil, err) + }) + + t.Run("client", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + writeCnt := 0 + conn := &mockConn{ + write: func(b []byte) (int, error) { + msg, err := decodeMessage(context.Background(), b, nil, remote.Client) + test.Assert(t, err == nil, err) + if writeCnt == 0 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, codec.MessageFrameType(msg)) + } else if writeCnt == 1 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeTrailer, codec.MessageFrameType(msg)) + test.Assert(t, msg.TransInfo().TransIntInfo()[transmeta.TransCode] != "", msg.TransInfo()) + test.Assert(t, msg.TransInfo().TransIntInfo()[transmeta.TransMessage] != "", msg.TransInfo()) + } + writeCnt += 1 + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + err := st.closeSend(errors.New("XXX")) + test.Assert(t, err == nil, err) + test.Assert(t, writeCnt == 2) + _, active := <-st.sendLoopCloseSignal + test.Assert(t, !active, "sendLoopCloseSignal should be closed") + }) + + t.Run("server", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + writeCnt := 0 + conn := &mockConn{ + write: func(b []byte) (int, error) { + msg, err := decodeMessage(context.Background(), b, nil, remote.Client) + test.Assert(t, err == nil, err) + if writeCnt == 0 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, codec.MessageFrameType(msg)) + } else if writeCnt == 1 { + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeTrailer, codec.MessageFrameType(msg)) + test.Assert(t, msg.TransInfo().TransIntInfo()[transmeta.TransCode] != "", msg.TransInfo()) + test.Assert(t, msg.TransInfo().TransIntInfo()[transmeta.TransMessage] != "", msg.TransInfo()) + } + writeCnt += 1 + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + err := st.closeSend(errors.New("XXX")) + test.Assert(t, err == nil, err) + test.Assert(t, writeCnt == 2) + _, active := <-st.sendLoopCloseSignal + test.Assert(t, !active, "sendLoopCloseSignal should be closed") + test.Assert(t, st.recvState.isClosed(), st.recvState) + test.Assert(t, errors.Is(st.ctx.Err(), context.Canceled), st.ctx.Err()) + }) +} + +func Test_ttheaderStream_newTrailerFrame(t *testing.T) { + t.Run("write-trailer-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + mockErr := errors.New("XXX") + handler := &mockFrameMetaHandler{ + writeTrailer: func(ctx context.Context, message remote.Message) error { + return mockErr + }, + } + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + msg := st.newTrailerFrame(nil) + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeTrailer, msg.TransInfo()) + }) + + t.Run("normal", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := &mockFrameMetaHandler{} + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + st.SetTrailer(metadata.MD{"key": []string{"v1", "v2"}}) + + msg := st.newTrailerFrame(nil) + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeTrailer, msg.TransInfo()) + test.Assert(t, len(msg.TransInfo().TransStrInfo()[codec.StrKeyMetaData]) != 0, msg.TransInfo()) + }) +} + +func Test_ttheaderStream_sendLoop(t *testing.T) { + t.Run("normal:send-header", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := &mockFrameMetaHandler{} + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + sendReq := sendRequest{ + message: msg, + finishSignal: make(chan error, 1), + } + st.sendQueue <- sendReq + err := sendReq.waitFinishSignal(st.Context()) + test.Assert(t, err == nil, err) + test.Assert(t, !st.sendState.isClosed(), st.recvState) + }) + + t.Run("normal:send-trailer-with-extra-message", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + sendFinishSignal := make(chan struct{}) + conn := &mockConn{ + write: func(b []byte) (int, error) { + <-sendFinishSignal // make sure all sendRequests are put into the sendQueue + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := &mockFrameMetaHandler{} + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + var requests []sendRequest + for i := 0; i < 10; i++ { + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeTrailer + sendReq := sendRequest{ + message: msg, + finishSignal: make(chan error, 1), + } + requests = append(requests, sendReq) + st.sendQueue <- sendReq + } + close(sendFinishSignal) + for i := 0; i < 10; i++ { + err := requests[i].waitFinishSignal(st.Context()) + if i == 0 { + test.Assert(t, err == nil, err) + } else { + test.Assert(t, errors.Is(err, ErrSendClosed), err) + } + } + test.Assert(t, st.sendState.isClosed(), st.recvState) + _, active := <-st.sendLoopCloseSignal + test.Assert(t, !active, "sendLoopCloseSignal should be closed") + }) + + t.Run("ctx-done", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + conn := &mockConn{ + write: func(b []byte) (int, error) { + msg, err := decodeMessage(ctx, b, nil, remote.Client) + test.Assert(t, err == nil, err) + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeTrailer, msg.TransInfo()) + test.Assert(t, msg.TransInfo().TransIntInfo()[transmeta.TransMessage] == context.Canceled.Error(), + msg.TransInfo().TransIntInfo()) + return len(b), nil + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + cancel() + st.waitSendLoopClose() + test.Assert(t, st.sendState.isClosed(), st.recvState) + }) + + t.Run("write-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + conn := &mockConn{ + write: func(b []byte) (int, error) { + return 0, io.ErrClosedPipe + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + sendReq := sendRequest{ + message: msg, + finishSignal: make(chan error, 1), + } + st.sendQueue <- sendReq + err := sendReq.waitFinishSignal(st.Context()) + test.Assert(t, err != nil, err) + test.Assert(t, st.sendState.isClosed(), st.recvState) + }) +} + +func Test_ttheaderStream_notifyAllSendRequests(t *testing.T) { + st := &ttheaderStream{ + sendQueue: make(chan sendRequest, 10), + } + n := 5 + wg := sync.WaitGroup{} + wg.Add(n) + for i := 0; i < n; i++ { + req := sendRequest{finishSignal: make(chan error, 1)} + st.sendQueue <- req + go func() { + defer wg.Done() + _ = req.waitFinishSignal(context.Background()) + }() + } + st.notifyAllSendRequests() + wg.Wait() + // if everything works well, wg.Wait() will not block +} + +func Test_ttheaderStream_writeMessage(t *testing.T) { + t.Run("success", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + mockErr := errors.New("mock error") + conn := &mockConn{ + write: func(b []byte) (int, error) { + msg, err := decodeMessage(ctx, b, nil, remote.Client) + test.Assert(t, err == nil, err) + test.Assert(t, codec.MessageFrameType(msg) == codec.FrameTypeHeader, msg.TransInfo()) + return 0, mockErr + }, + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Server) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + err := st.writeMessage(msg) + test.Assert(t, errors.Is(err, mockErr), err) + }) + + t.Run("encode-err", func(t *testing.T) { + ri := mockRPCInfo() + remote.PutPayloadCode(serviceinfo.Thrift, thrift.NewThriftCodec()) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + data := &sendRequest{} + msg := remote.NewMessage(data, nil, ri, remote.Stream, remote.Server) + msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, serviceinfo.Thrift)) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeData + err := st.writeMessage(msg) + test.Assert(t, err != nil, err) + test.Assert(t, strings.Contains(err.Error(), "thriftCodec"), err) + }) +} + +func Test_ttheaderStream_clientReadMetaFrame(t *testing.T) { + t.Run("recv-closed", func(t *testing.T) { + st := &ttheaderStream{} + st.recvState.setClosed() + st.lastRecvError = io.EOF + err := st.clientReadMetaFrame() + test.Assert(t, err == io.EOF, err) + }) + t.Run("success", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + conn := &mockConn{ + readBuf: func() []byte { + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeMeta + buf, _ := encodeMessage(ctx, msg) + return buf + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getClientFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Client, ext, handler) + defer st.Close() // avoid goroutine leak + + err := st.clientReadMetaFrame() + test.Assert(t, err == nil, err) + }) +} + +func Test_ttheaderStream_serverReadHeaderMeta(t *testing.T) { + t.Run("success", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + opt := &remote.ServerOption{ + Option: remote.Option{ + StreamingMetaHandlers: []remote.StreamingMetaHandler{ + remote.NewCustomMetaHandler(remote.WithOnReadStream( + func(ctx context.Context) (context.Context, error) { + return context.WithValue(ctx, "key", "value"), nil + }, + )).(remote.StreamingMetaHandler), + }, + }, + } + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + msg.TransInfo().TransIntInfo()[transmeta.ToMethod] = "ToMethod" + + ctx, err := st.serverReadHeaderMeta(opt, msg) + test.Assert(t, err == nil, err) + test.Assert(t, msg.RPCInfo().Invocation().MethodName() == "ToMethod") + test.Assert(t, ctx.Value("key") == "value", ctx.Value("key")) + }) + + t.Run("frameHandler.ServerReadHeader:err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + opt := &remote.ServerOption{} + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + + _, err := st.serverReadHeaderMeta(opt, msg) + test.Assert(t, err != nil, err) + }) + + t.Run("OnReadStream-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + opt := &remote.ServerOption{ + Option: remote.Option{ + StreamingMetaHandlers: []remote.StreamingMetaHandler{ + remote.NewCustomMetaHandler(remote.WithOnReadStream( + func(ctx context.Context) (context.Context, error) { + return nil, errors.New("err") + }, + )).(remote.StreamingMetaHandler), + }, + }, + } + + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + msg.TransInfo().TransIntInfo()[transmeta.ToMethod] = "ToMethod" + + _, err := st.serverReadHeaderMeta(opt, msg) + test.Assert(t, err != nil, err) + }) +} + +func Test_ttheaderStream_serverReadHeaderFrame(t *testing.T) { + t.Run("success", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + conn := &mockConn{ + readBuf: func() []byte { + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + msg.TransInfo().TransIntInfo()[transmeta.ToMethod] = "ToMethod" + buf, _ := encodeMessage(ctx, msg) + return buf + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + opt := &remote.ServerOption{ + TracerCtl: &rpcinfo.TraceController{}, + } + opt.TracerCtl.Append(&mockTracer{ + start: func(ctx context.Context) context.Context { + return context.WithValue(ctx, "key", "value") + }, + }) + + ctx, err := st.serverReadHeaderFrame(opt) + test.Assert(t, err == nil, err) + test.Assert(t, ctx.Value("key") == "value", ctx.Value("key")) + test.Assert(t, st.recvState.hasHeader(), st.recvState) + }) + + t.Run("read-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + opt := &remote.ServerOption{ + TracerCtl: &rpcinfo.TraceController{}, + } + + _, err := st.serverReadHeaderFrame(opt) + test.Assert(t, err != nil, err) + }) + + t.Run("read-err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + conn := &mockConn{} + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + opt := &remote.ServerOption{ + TracerCtl: &rpcinfo.TraceController{}, + } + + _, err := st.serverReadHeaderFrame(opt) + test.Assert(t, err != nil, err) + }) + + t.Run("dirty-frame", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + conn := &mockConn{ + readBuf: func() []byte { + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeTrailer + buf, _ := encodeMessage(ctx, msg) + return buf + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + handler := getServerFrameMetaHandler(nil) + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + opt := &remote.ServerOption{ + TracerCtl: &rpcinfo.TraceController{}, + } + + ctx, err := st.serverReadHeaderFrame(opt) + test.Assert(t, errors.Is(err, ErrDirtyFrame), err) + }) + + t.Run("serverReadHeaderMeta:err", func(t *testing.T) { + ri := mockRPCInfo() + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + conn := &mockConn{ + readBuf: func() []byte { + msg := remote.NewMessage(nil, nil, ri, remote.Stream, remote.Client) + msg.TransInfo().TransIntInfo()[transmeta.FrameType] = codec.FrameTypeHeader + msg.TransInfo().TransIntInfo()[transmeta.ToMethod] = "ToMethod" + buf, _ := encodeMessage(ctx, msg) + return buf + }(), + } + streamCodec := ttheader.NewStreamCodec() + ext := &mockExtension{} + mockErr := errors.New("mock error") + handler := &mockFrameMetaHandler{ + serverReadHeader: func(ctx context.Context, msg remote.Message) (context.Context, error) { + return nil, mockErr + }, + } + + st := newTTHeaderStream(ctx, conn, ri, streamCodec, remote.Server, ext, handler) + defer st.Close() // avoid goroutine leak + + opt := &remote.ServerOption{ + TracerCtl: &rpcinfo.TraceController{}, + } + opt.TracerCtl.Append(&mockTracer{ + start: func(ctx context.Context) context.Context { + return context.WithValue(ctx, "key", "value") + }, + }) + + ctx, err := st.serverReadHeaderFrame(opt) + test.Assert(t, errors.Is(err, mockErr), err) + }) +} diff --git a/pkg/remote/trans/ttheaderstreaming/util.go b/pkg/remote/trans/ttheaderstreaming/util.go new file mode 100644 index 0000000000..1d36116550 --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/util.go @@ -0,0 +1,83 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import ( + "context" + "errors" + "net" + + "github.com/bytedance/sonic" + "github.com/cloudwego/netpoll" + + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" + "github.com/cloudwego/kitex/pkg/rpcinfo" +) + +type netpollReader interface { + Reader() netpoll.Reader +} + +func getReader(conn net.Conn) netpoll.Reader { + return conn.(netpollReader).Reader() +} + +// coverBizStatusError saves the BizStatusErr (if any) into the RPCInfo and returns nil since it's not an RPC error +// server.invokeHandleEndpoint already processed BizStatusError, so it's only used by unknownMethodHandler +func coverBizStatusError(err error, ri rpcinfo.RPCInfo) error { + var bizErr kerrors.BizStatusErrorIface + if errors.As(err, &bizErr) { + err = nil + if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok { + setter.SetBizStatusErr(bizErr) + } + } + return err +} + +// getRemoteAddr returns the remote address of the connection. +func getRemoteAddr(ctx context.Context, conn net.Conn) net.Addr { + addr := conn.RemoteAddr() + if addr != nil && addr.Network() != "unix" { + // unix socket: local sidecar proxy + return addr + } + // likely from service mesh: return the addr read from mesh ttheader + ri := rpcinfo.GetRPCInfo(ctx) + if ri == nil { + return addr + } + if ri.From().Address() != nil { + return ri.From().Address() + } + return addr +} + +// injectGRPCMetadata parses grpc style metadata into ctx, helpful for projects migrating from kitex-grpc +func injectGRPCMetadata(ctx context.Context, strInfo map[string]string) (context.Context, error) { + value, exists := strInfo[codec.StrKeyMetaData] + if !exists { + return ctx, nil + } + md := metadata.MD{} + if err := sonic.Unmarshal([]byte(value), &md); err != nil { + return ctx, err + } + return metadata.NewIncomingContext(ctx, md), nil +} diff --git a/pkg/remote/trans/ttheaderstreaming/util_test.go b/pkg/remote/trans/ttheaderstreaming/util_test.go new file mode 100644 index 0000000000..d4a567ee68 --- /dev/null +++ b/pkg/remote/trans/ttheaderstreaming/util_test.go @@ -0,0 +1,128 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttheaderstreaming + +import ( + "context" + "errors" + "net" + "testing" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" + "github.com/cloudwego/kitex/pkg/rpcinfo" +) + +func Test_coverBizStatusError(t *testing.T) { + t.Run("nil", func(t *testing.T) { + var err error + ivk := rpcinfo.NewInvocation("svc", "method") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, nil) + + got := coverBizStatusError(err, ri) + + test.Assert(t, got == nil) + test.Assert(t, ri.Invocation().BizStatusErr() == nil) + }) + t.Run("not-biz-status-error", func(t *testing.T) { + err := errors.New("XXX") + ivk := rpcinfo.NewInvocation("svc", "method") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, nil) + + got := coverBizStatusError(err, ri) + + test.Assert(t, errors.Is(got, err)) + test.Assert(t, ri.Invocation().BizStatusErr() == nil) + }) + t.Run("biz-status-err", func(t *testing.T) { + err := kerrors.NewBizStatusError(1, "XXX") + ivk := rpcinfo.NewInvocation("svc", "method") + ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, nil) + + got := coverBizStatusError(err, ri) + + test.Assert(t, got == nil) + test.Assert(t, ri.Invocation().BizStatusErr() != nil) + }) +} + +func Test_getRemoteAddr(t *testing.T) { + tcpAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8888") + unixAddr, _ := net.ResolveUnixAddr("unix", "demo.sock") + t.Run("tcp-addr", func(t *testing.T) { + conn := &mockConn{remoteAddr: tcpAddr} + got := getRemoteAddr(context.Background(), conn) + test.Assert(t, got.String() == "127.0.0.1:8888") + }) + t.Run("unix-addr-without-rpcinfo", func(t *testing.T) { + conn := &mockConn{remoteAddr: unixAddr} + got := getRemoteAddr(context.Background(), conn) + test.Assert(t, got == unixAddr) + }) + t.Run("unix-addr-without-rpcinfo-from-addr", func(t *testing.T) { + conn := &mockConn{remoteAddr: unixAddr} + from := rpcinfo.NewEndpointInfo("svc", "method", nil, nil) + ri := rpcinfo.NewRPCInfo(from, nil, nil, nil, nil) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + got := getRemoteAddr(ctx, conn) + test.Assert(t, got == unixAddr) + }) + t.Run("unix-addr-with-rpcinfo-from-addr", func(t *testing.T) { + conn := &mockConn{remoteAddr: unixAddr} + from := rpcinfo.NewEndpointInfo("svc", "method", tcpAddr, nil) + ri := rpcinfo.NewRPCInfo(from, nil, nil, nil, nil) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + got := getRemoteAddr(ctx, conn) + test.Assert(t, got == tcpAddr) + }) +} + +func Test_injectGRPCMetadata(t *testing.T) { + t.Run("no-grpc-metadata", func(t *testing.T) { + ctx := context.Background() + got, err := injectGRPCMetadata(ctx, map[string]string{}) + test.Assert(t, err == nil) + test.Assert(t, got == ctx) + }) + + t.Run("grpc-metadata", func(t *testing.T) { + ctx := context.Background() + strInfo := map[string]string{ + codec.StrKeyMetaData: `{"k1":["v1","v2"],"k2":["v3","v4"]}`, + } + got, err := injectGRPCMetadata(ctx, strInfo) + test.Assert(t, err == nil) + test.Assert(t, got != ctx) + md, ok := metadata.FromIncomingContext(got) + test.Assert(t, ok) + test.Assert(t, md["k1"][0] == "v1") + test.Assert(t, md["k1"][1] == "v2") + test.Assert(t, md["k2"][0] == "v3") + test.Assert(t, md["k2"][1] == "v4") + }) + + t.Run("invalid-grpc-metadata", func(t *testing.T) { + ctx := context.Background() + strInfo := map[string]string{ + codec.StrKeyMetaData: `invalid`, + } + _, err := injectGRPCMetadata(ctx, strInfo) + test.Assert(t, err != nil) + }) +} diff --git a/pkg/remote/trans_handler.go b/pkg/remote/trans_handler.go index 7d46a6bb7f..a193464972 100644 --- a/pkg/remote/trans_handler.go +++ b/pkg/remote/trans_handler.go @@ -21,6 +21,8 @@ import ( "net" "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streaming" ) // ClientTransHandlerFactory to new TransHandler for client @@ -28,6 +30,11 @@ type ClientTransHandlerFactory interface { NewTransHandler(opt *ClientOption) (ClientTransHandler, error) } +// ClientStreamTransHandlerFactory to new StreamTransHandler for client +type ClientStreamTransHandlerFactory interface { + NewStreamTransHandler(opt *ClientOption) (ClientTransHandler, error) +} + // ServerTransHandlerFactory to new TransHandler for server type ServerTransHandlerFactory interface { NewTransHandler(opt *ServerOption) (ServerTransHandler, error) @@ -54,6 +61,12 @@ type ClientTransHandler interface { TransHandler } +type ClientStreamAllocator interface { + NewStream( + ctx context.Context, svcInfo *serviceinfo.ServiceInfo, conn net.Conn, handler TransReadWriter, + ) (streaming.Stream, error) +} + // ServerTransHandler have some new functions. type ServerTransHandler interface { TransHandler diff --git a/pkg/remote/trans_meta.go b/pkg/remote/trans_meta.go index 00ed7a12e6..a2c56d48cb 100644 --- a/pkg/remote/trans_meta.go +++ b/pkg/remote/trans_meta.go @@ -28,8 +28,28 @@ type MetaHandler interface { // StreamingMetaHandler reads or writes metadata through streaming header(http2 header) type StreamingMetaHandler interface { - // writes metadata before create a stream + // OnConnectStream writes metadata before create a stream OnConnectStream(ctx context.Context) (context.Context, error) - // reads metadata before read first message from stream + // OnReadStream reads metadata before read first message from stream OnReadStream(ctx context.Context) (context.Context, error) } + +// FrameMetaHandler deals with MetaFrame, HeaderFrame and TrailerFrame +// It's used by TTHeader Streaming +type FrameMetaHandler interface { + // ReadMeta is called after a MetaFrame is received by client + // MetaFrame is not necessary and a typical scenario is sent by a proxy which takes over service discovery, + // By sending the MetaFrame before server's HeaderFrame, proxy can inform the client about the real address. + ReadMeta(ctx context.Context, message Message) error + // ClientReadHeader is called after a HeaderFrame is received by client + ClientReadHeader(ctx context.Context, message Message) error + // ServerReadHeader is called after a HeaderFrame is received by server + // The context returned will be used to replace the stream context + ServerReadHeader(ctx context.Context, message Message) (context.Context, error) + // WriteHeader is called before a HeaderFrame is sent by client/server + WriteHeader(ctx context.Context, message Message) error + // ReadTrailer is called after a TrailerFrame is received by client/server + ReadTrailer(ctx context.Context, message Message) error + // WriteTrailer is called before a TrailerFrame is sent by client/server + WriteTrailer(ctx context.Context, message Message) error +} diff --git a/pkg/remote/transmeta/metakey.go b/pkg/remote/transmeta/metakey.go index a2ad43ce3d..7ce41db54c 100644 --- a/pkg/remote/transmeta/metakey.go +++ b/pkg/remote/transmeta/metakey.go @@ -47,6 +47,10 @@ const ( HTTPContentType RawRingHashKey LBType + ClusterShardID + FrameType // ttheader streaming frame type + TransCode + TransMessage ) // key of header transport diff --git a/pkg/rpcinfo/rpcconfig.go b/pkg/rpcinfo/rpcconfig.go index 996c322c3f..c18fac054f 100644 --- a/pkg/rpcinfo/rpcconfig.go +++ b/pkg/rpcinfo/rpcconfig.go @@ -17,6 +17,7 @@ package rpcinfo import ( + "fmt" "sync" "time" @@ -166,7 +167,18 @@ func (r *rpcConfig) TransportProtocol() transport.Protocol { // SetTransportProtocol implements MutableRPCConfig interface. func (r *rpcConfig) SetTransportProtocol(tp transport.Protocol) error { - r.transportProtocol |= tp + if tp&transport.TTHeaderFramed != 0 { + if tp&(^transport.TTHeaderFramed) != 0 { + panic(fmt.Sprintf("invalid transport protocol: %b", tp)) + } + // TTHeader and Framed can be combined for [TTHeader + [FramedSize + Payload]] + r.transportProtocol &= transport.TTHeaderFramed // clear bits except TTHeader | Framed + r.transportProtocol |= tp + } else { + // other transports are mutually exclusive + // it's user's responsibility to set only one transport, not an OR-ed combination of multiple transports + r.transportProtocol = tp + } return nil } diff --git a/pkg/rpcinfo/rpcconfig_test.go b/pkg/rpcinfo/rpcconfig_test.go index d7effd7ec1..5c385f6189 100644 --- a/pkg/rpcinfo/rpcconfig_test.go +++ b/pkg/rpcinfo/rpcconfig_test.go @@ -33,3 +33,56 @@ func TestRPCConfig(t *testing.T) { test.Assert(t, c.IOBufferSize() != 0) test.Assert(t, c.TransportProtocol() == transport.PurePayload) } + +func TestSetTransportProtocol(t *testing.T) { + t.Run("set-ttheader", func(t *testing.T) { + c := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.TTHeader) + test.Assert(t, c.TransportProtocol() == transport.TTHeader, c.TransportProtocol()) + }) + t.Run("set-framed", func(t *testing.T) { + c := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.Framed) + test.Assert(t, c.TransportProtocol() == transport.Framed, c.TransportProtocol()) + }) + t.Run("set-ttheader-framed", func(t *testing.T) { + c := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.TTHeaderFramed) + test.Assert(t, c.TransportProtocol() == transport.TTHeaderFramed, c.TransportProtocol()) + }) + t.Run("set-ttheader-set-framed", func(t *testing.T) { + c := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.TTHeader) + test.Assert(t, c.TransportProtocol() == transport.TTHeader, c.TransportProtocol()) + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.Framed) + test.Assert(t, c.TransportProtocol() == transport.TTHeaderFramed, c.TransportProtocol()) + }) + t.Run("set-framed-set-ttheader", func(t *testing.T) { + c := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.Framed) + test.Assert(t, c.TransportProtocol() == transport.Framed, c.TransportProtocol()) + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.TTHeader) + test.Assert(t, c.TransportProtocol() == transport.TTHeaderFramed, c.TransportProtocol()) + }) + t.Run("set-ttheader-set-grpc", func(t *testing.T) { + c := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.TTHeader) + test.Assert(t, c.TransportProtocol() == transport.TTHeader, c.TransportProtocol()) + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.GRPC) + test.Assert(t, c.TransportProtocol() == transport.GRPC, c.TransportProtocol()) + }) + t.Run("set-grpc-set-ttheader", func(t *testing.T) { + c := rpcinfo.NewRPCConfig() + rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.GRPC) + test.Assert(t, c.TransportProtocol() == transport.GRPC, c.TransportProtocol()) + rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.TTHeader) + test.Assert(t, c.TransportProtocol() == transport.TTHeader, c.TransportProtocol()) + }) + t.Run("set-invalid-transport", func(t *testing.T) { + defer func() { + test.Assert(t, recover() != nil) + }() + c := rpcinfo.NewRPCConfig() + _ = rpcinfo.AsMutableRPCConfig(c).SetTransportProtocol(transport.TTHeader | transport.GRPC) + }) +} diff --git a/pkg/serviceinfo/serviceinfo.go b/pkg/serviceinfo/serviceinfo.go index 5036d1cfce..c1d73f9c36 100644 --- a/pkg/serviceinfo/serviceinfo.go +++ b/pkg/serviceinfo/serviceinfo.go @@ -28,6 +28,7 @@ const ( Thrift PayloadCodec = iota Protobuf Hessian2 + NotSpecified PayloadCodec = 255 ) const ( diff --git a/pkg/transmeta/ttheader.go b/pkg/transmeta/ttheader.go index 801ba3f025..49c62a1d66 100644 --- a/pkg/transmeta/ttheader.go +++ b/pkg/transmeta/ttheader.go @@ -89,20 +89,26 @@ func (ch *clientTTHeaderHandler) ReadMeta(ctx context.Context, msg remote.Messag transInfo := msg.TransInfo() strInfo := transInfo.TransStrInfo() + if err := ParseBizStatusErrorToRPCInfo(strInfo, ri); err != nil { + return nil, err + } + return ctx, nil +} + +// ParseBizStatusErrorToRPCInfo parse biz status error from strInfo and set it to rpcinfo.Invocation +func ParseBizStatusErrorToRPCInfo(strInfo map[string]string, ri rpcinfo.RPCInfo) (err error) { if code, err := strconv.Atoi(strInfo[bizStatus]); err == nil && code != 0 { if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok { + var extra map[string]string if bizExtra := strInfo[bizExtra]; bizExtra != "" { - extra, err := utils.JSONStr2Map(bizExtra) - if err != nil { - return ctx, fmt.Errorf("malformed header info, extra: %s", bizExtra) + if extra, err = utils.JSONStr2Map(bizExtra); err != nil { + return fmt.Errorf("malformed header info, extra: %s", bizExtra) } - setter.SetBizStatusErr(kerrors.NewBizStatusErrorWithExtra(int32(code), strInfo[bizMessage], extra)) - } else { - setter.SetBizStatusErr(kerrors.NewBizStatusError(int32(code), strInfo[bizMessage])) } + setter.SetBizStatusErr(kerrors.NewBizStatusErrorWithExtra(int32(code), strInfo[bizMessage], extra)) } } - return ctx, nil + return nil } // serverTTHeaderHandler implement remote.MetaHandler @@ -147,16 +153,19 @@ func (sh *serverTTHeaderHandler) WriteMeta(ctx context.Context, msg remote.Messa strInfo := transInfo.TransStrInfo() intInfo[transmeta.MsgType] = strconv.Itoa(int(msg.MessageType())) + InjectBizStatusError(strInfo, ri.Invocation().BizStatusErr()) + return ctx, nil +} - if bizErr := ri.Invocation().BizStatusErr(); bizErr != nil { - strInfo[bizStatus] = strconv.Itoa(int(bizErr.BizStatusCode())) - strInfo[bizMessage] = bizErr.BizMessage() - if len(bizErr.BizExtra()) != 0 { - strInfo[bizExtra], _ = utils.Map2JSONStr(bizErr.BizExtra()) - } +func InjectBizStatusError(strInfo map[string]string, bizErr kerrors.BizStatusErrorIface) { + if bizErr == nil { + return + } + strInfo[bizStatus] = strconv.Itoa(int(bizErr.BizStatusCode())) + strInfo[bizMessage] = bizErr.BizMessage() + if len(bizErr.BizExtra()) != 0 { + strInfo[bizExtra], _ = utils.Map2JSONStr(bizErr.BizExtra()) } - - return ctx, nil } func isTTHeader(msg remote.Message) bool { diff --git a/pkg/transmeta/ttheader_test.go b/pkg/transmeta/ttheader_test.go index c01963e4dd..20b4fa0aa5 100644 --- a/pkg/transmeta/ttheader_test.go +++ b/pkg/transmeta/ttheader_test.go @@ -24,6 +24,7 @@ import ( "github.com/cloudwego/kitex/internal/mocks" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/transmeta" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -132,3 +133,66 @@ func TestTTHeaderServerWriteMetainfo(t *testing.T) { kvs = msg.TransInfo().TransIntInfo() test.Assert(t, kvs[transmeta.MsgType] == strconv.Itoa(int(remote.Call))) } + +func TestInjectBizStatusError(t *testing.T) { + t.Run("nil-err", func(t *testing.T) { + strInfo := map[string]string{} + InjectBizStatusError(strInfo, nil) + test.Assert(t, len(strInfo) == 0) + }) + t.Run("biz-err", func(t *testing.T) { + strInfo := map[string]string{} + bizErr := kerrors.NewBizStatusErrorWithExtra(int32(100), "biz-err", map[string]string{"k1": "v1"}) + InjectBizStatusError(strInfo, bizErr) + test.Assert(t, len(strInfo) == 3) + test.Assert(t, strInfo[bizStatus] == "100") + test.Assert(t, strInfo[bizMessage] == "biz-err") + test.Assert(t, strInfo[bizExtra] == `{"k1":"v1"}`) + }) +} + +func TestParseBizStatusErrorToRPCInfo(t *testing.T) { + t.Run("no-biz-err", func(t *testing.T) { + strInfo := map[string]string{} + ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), nil, nil) + err := ParseBizStatusErrorToRPCInfo(strInfo, ri) + test.Assert(t, err == nil) + }) + t.Run("biz-err", func(t *testing.T) { + strInfo := map[string]string{ + bizStatus: "100", + bizMessage: "biz-err", + bizExtra: `{"k1":"v1"}`, + } + ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), nil, nil) + err := ParseBizStatusErrorToRPCInfo(strInfo, ri) + test.Assert(t, err == nil) + bizErr := ri.Invocation().BizStatusErr() + test.Assert(t, bizErr != nil) + test.Assert(t, bizErr.BizStatusCode() == 100) + test.Assert(t, bizErr.BizMessage() == "biz-err") + test.Assert(t, bizErr.BizExtra()["k1"] == "v1") + }) + t.Run("biz-err:code=0", func(t *testing.T) { + strInfo := map[string]string{ + bizStatus: "0", + bizMessage: "biz-err", + bizExtra: `{"k1":"v1"}`, + } + ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), nil, nil) + err := ParseBizStatusErrorToRPCInfo(strInfo, ri) + test.Assert(t, err == nil) + bizErr := ri.Invocation().BizStatusErr() + test.Assert(t, bizErr == nil) + }) + t.Run("invalid-extra", func(t *testing.T) { + strInfo := map[string]string{ + bizStatus: "100", + bizMessage: "biz-err", + bizExtra: `invalid`, + } + ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), nil, nil) + err := ParseBizStatusErrorToRPCInfo(strInfo, ri) + test.Assert(t, err != nil) + }) +} diff --git a/server/option_advanced.go b/server/option_advanced.go index cf959a8153..b7bf383592 100644 --- a/server/option_advanced.go +++ b/server/option_advanced.go @@ -263,3 +263,11 @@ func WithProfilerMessageTagging(tagging remote.MessageTagging) Option { } }} } + +// WithTTHeaderFrameMetaHandler sets the FrameMetaHandler for TTHeader Streaming +func WithTTHeaderFrameMetaHandler(h remote.FrameMetaHandler) Option { + return Option{F: func(o *internal_server.Options, di *utils.Slice) { + di.Push(fmt.Sprintf("WithTTHeaderFrameMetaHandler(%T)", h)) + o.RemoteOpt.TTHeaderFrameMetaHandler = h + }} +} diff --git a/server/option_test.go b/server/option_test.go index 1a9609b32e..0eebad36e7 100644 --- a/server/option_test.go +++ b/server/option_test.go @@ -293,7 +293,10 @@ func TestMuxTransportOption(t *testing.T) { err = svr1.Run() test.Assert(t, err == nil, err) iSvr1 := svr1.(*server) - test.DeepEqual(t, iSvr1.opt.RemoteOpt.SvrHandlerFactory, detection.NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory())) + test.DeepEqual(t, iSvr1.opt.RemoteOpt.SvrHandlerFactory, + detection.NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), + netpoll.NewTTHeaderStreamingSvrTransHandlerFactory(), + nphttp2.NewSvrTransHandlerFactory())) svr2 := NewServer(WithMuxTransport()) time.AfterFunc(100*time.Millisecond, func() { diff --git a/tool/internal_pkg/generator/streaming.go b/tool/internal_pkg/generator/streaming.go new file mode 100644 index 0000000000..1794769916 --- /dev/null +++ b/tool/internal_pkg/generator/streaming.go @@ -0,0 +1,40 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package generator + +import ( + "fmt" + "strings" +) + +const ( + GRPC = "GRPC" // transport.GRPC + TTHeader = "TTHeader" // transport.TTHeader + AnnotationKeyStreamingTransport = "streaming.transport" +) + +var supportedStreamingTransport = map[string]string{ + strings.ToUpper(GRPC): GRPC, + strings.ToUpper(TTHeader): TTHeader, +} + +// GetStreamingTransport converts the transport string to the corresponding symbol name in kitex source code. +func GetStreamingTransport(transport string) (string, error) { + transportUppercase := strings.ToUpper(transport) + if kitexTransport, exists := supportedStreamingTransport[transportUppercase]; exists { + return kitexTransport, nil + } + return "", fmt.Errorf("unsupported streaming.transport: %s", transport) +} diff --git a/tool/internal_pkg/generator/streaming_test.go b/tool/internal_pkg/generator/streaming_test.go new file mode 100644 index 0000000000..1bf203f5d8 --- /dev/null +++ b/tool/internal_pkg/generator/streaming_test.go @@ -0,0 +1,43 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package generator + +import ( + "testing" + + "github.com/cloudwego/thriftgo/pkg/test" +) + +func TestGetStreamingTransport(t *testing.T) { + t.Run("grpc", func(t *testing.T) { + tp, err := GetStreamingTransport("grpc") + test.Assert(t, err == nil, err) + test.Assert(t, tp == GRPC) + }) + t.Run("gRPC", func(t *testing.T) { + tp, err := GetStreamingTransport("gRPC") + test.Assert(t, err == nil, err) + test.Assert(t, tp == GRPC) + }) + t.Run("ttheader", func(t *testing.T) { + tp, err := GetStreamingTransport("ttheader") + test.Assert(t, err == nil, err) + test.Assert(t, tp == TTHeader) + }) + t.Run("unknown", func(t *testing.T) { + _, err := GetStreamingTransport("xxx") + test.Assert(t, err != nil, err) + }) +} diff --git a/tool/internal_pkg/generator/type.go b/tool/internal_pkg/generator/type.go index 367ac997e4..a042b29068 100644 --- a/tool/internal_pkg/generator/type.go +++ b/tool/internal_pkg/generator/type.go @@ -106,6 +106,7 @@ type ServiceInfo struct { Protocol string HandlerReturnKeepResp bool UseThriftReflection bool + StreamingTransport string } // AllMethods returns all methods that the service have. diff --git a/tool/internal_pkg/pluginmode/protoc/plugin.go b/tool/internal_pkg/pluginmode/protoc/plugin.go index 6394137e3d..a0ec0292e5 100644 --- a/tool/internal_pkg/pluginmode/protoc/plugin.go +++ b/tool/internal_pkg/pluginmode/protoc/plugin.go @@ -20,6 +20,7 @@ import ( "fmt" "path" "path/filepath" + "regexp" "strings" "text/template" @@ -264,6 +265,12 @@ func (pp *protocPlugin) convertTypes(file *protogen.File) (ss []*generator.Servi ServiceName: service.GoName, RawServiceName: string(service.Desc.Name()), } + if protocol, err := parseStreamingTransport(service.Comments); err == nil { + si.StreamingTransport = protocol + } else { + panic(fmt.Errorf("parseStreamingTransport for service %s failed: %v", si.ServiceName, err)) + } + si.ServiceTypeName = func() string { return si.PkgRefName + "." + si.ServiceName } for _, m := range service.Methods { req := pp.convertParameter(m.Input, "Req") @@ -297,9 +304,19 @@ func (pp *protocPlugin) convertTypes(file *protogen.File) (ss []*generator.Servi if pp.Config.CombineService && len(file.Services) > 0 { var svcs []*generator.ServiceInfo var methods []*generator.MethodInfo + var streamingTransport string for _, s := range ss { svcs = append(svcs, s) methods = append(methods, s.AllMethods()...) + if streamingTransport == "" { + streamingTransport = s.StreamingTransport + } else if streamingTransport != s.StreamingTransport { + log.Warnf("[WARN] service %s has a different streaming transport %s\n", + s.ServiceName, s.StreamingTransport) + } + } + if streamingTransport == "" { + streamingTransport = generator.GRPC // default } // check method name conflict mm := make(map[string]*generator.MethodInfo) @@ -319,12 +336,13 @@ func (pp *protocPlugin) convertTypes(file *protogen.File) (ss []*generator.Servi } svcName := pp.getCombineServiceName("CombineService", ss) si := &generator.ServiceInfo{ - PkgInfo: pi, - ServiceName: svcName, - RawServiceName: svcName, - CombineServices: svcs, - Methods: methods, - HasStreaming: hasStreaming, + PkgInfo: pi, + ServiceName: svcName, + RawServiceName: svcName, + CombineServices: svcs, + Methods: methods, + HasStreaming: hasStreaming, + StreamingTransport: streamingTransport, } si.ServiceTypeName = func() string { return si.ServiceName } ss = append(ss, si) @@ -332,6 +350,19 @@ func (pp *protocPlugin) convertTypes(file *protogen.File) (ss []*generator.Servi return } +var streamingTransportPattern = regexp.MustCompile(`@` + generator.AnnotationKeyStreamingTransport + `=(\w*)`) + +// Protobuf does not support annotation, so we parse it from leading comments in the format of: +// @streaming.transport= +func parseStreamingTransport(comments protogen.CommentSet) (string, error) { + match := streamingTransportPattern.FindAllStringSubmatch(string(comments.Leading), 1) + if len(match) == 0 { + return generator.GRPC, nil + } + protocol := match[0][1] + return generator.GetStreamingTransport(protocol) +} + // BuildStreaming builds protobuf MethodInfo.Streaming as for Thrift, to simplify codegen func BuildStreaming(mi *generator.MethodInfo, serviceHasStreaming bool) { s := &streaming.Streaming{ diff --git a/tool/internal_pkg/pluginmode/protoc/plugin_test.go b/tool/internal_pkg/pluginmode/protoc/plugin_test.go new file mode 100644 index 0000000000..e4523ccd59 --- /dev/null +++ b/tool/internal_pkg/pluginmode/protoc/plugin_test.go @@ -0,0 +1,56 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package protoc + +import ( + "testing" + + "google.golang.org/protobuf/compiler/protogen" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/tool/internal_pkg/generator" +) + +func Test_parseStreamingTransport(t *testing.T) { + t.Run("grpc", func(t *testing.T) { + comments := protogen.CommentSet{ + Leading: "// @streaming.transport=grpc", + } + tp, err := parseStreamingTransport(comments) + test.Assert(t, err == nil, err) + test.Assert(t, tp == generator.GRPC) + }) + t.Run("ttheader", func(t *testing.T) { + comments := protogen.CommentSet{ + Leading: "// @streaming.transport=ttheader", + } + tp, err := parseStreamingTransport(comments) + test.Assert(t, err == nil, err) + test.Assert(t, tp == generator.TTHeader) + }) + t.Run("no-comment", func(t *testing.T) { + comments := protogen.CommentSet{} + tp, err := parseStreamingTransport(comments) + test.Assert(t, err == nil, err) + test.Assert(t, tp == generator.GRPC) + }) + t.Run("unknown", func(t *testing.T) { + comments := protogen.CommentSet{ + Leading: "// @streaming.transport=xxx", + } + _, err := parseStreamingTransport(comments) + test.Assert(t, err != nil, err) + }) +} diff --git a/tool/internal_pkg/pluginmode/thriftgo/convertor.go b/tool/internal_pkg/pluginmode/thriftgo/convertor.go index ef1576df8f..3116cecf90 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/convertor.go +++ b/tool/internal_pkg/pluginmode/thriftgo/convertor.go @@ -138,8 +138,9 @@ func (c *converter) copyTreeWithRef(ast *parser.Thrift, ref string) *parser.Thri for _, s := range ast.Services { ss := &parser.Service{ - Name: s.Name, - Extends: s.Extends, + Name: s.Name, + Extends: s.Extends, + Annotations: c.copyAnnotations(s.Annotations), } for _, f := range s.Functions { ff := c.copyFunctionWithRef(f, ref) @@ -368,6 +369,12 @@ func (c *converter) convertTypes(req *plugin.Request) error { hasStreaming = hasStreaming || s.HasStreaming methods = append(methods, s.AllMethods()...) } + + streamingTransport, err := getCombineServiceStreamingTransport(all[ast.Filename]) + if err != nil { + return err + } + // check method name conflict mm := make(map[string]*generator.MethodInfo) for _, m := range methods { @@ -378,13 +385,14 @@ func (c *converter) convertTypes(req *plugin.Request) error { } svcName := c.getCombineServiceName("CombineService", all[ast.Filename]) si := &generator.ServiceInfo{ - PkgInfo: pi, - ServiceName: svcName, - RawServiceName: svcName, - CombineServices: svcs, - Methods: methods, - ServiceFilePath: ast.Filename, - HasStreaming: hasStreaming, + PkgInfo: pi, + ServiceName: svcName, + RawServiceName: svcName, + CombineServices: svcs, + Methods: methods, + ServiceFilePath: ast.Filename, + HasStreaming: hasStreaming, + StreamingTransport: streamingTransport, } if c.IsHessian2() { @@ -403,6 +411,24 @@ func (c *converter) convertTypes(req *plugin.Request) error { return nil } +// getCombineServiceStreamingTransport returns the common streaming transport for all services +// If the streaming transports are not the same, it returns an error +func getCombineServiceStreamingTransport(services []*generator.ServiceInfo) (string, error) { + var streamingTransport string + for _, s := range services { + if streamingTransport == "" { + streamingTransport = s.StreamingTransport + } else if streamingTransport != s.StreamingTransport { + return "", fmt.Errorf("service %s has different streaming transport: %s", + s.ServiceName, s.StreamingTransport) + } + } + if streamingTransport == "" { + streamingTransport = generator.GRPC // default + } + return streamingTransport, nil +} + func (c *converter) fixStreamingForExtendedServices(ast *parser.Thrift, all ast2svc) { for i, svc := range ast.Services { if svc.Extends == "" { @@ -415,14 +441,18 @@ func (c *converter) fixStreamingForExtendedServices(ast *parser.Thrift, all ast2 } } -func (c *converter) makeService(pkg generator.PkgInfo, svc *golang.Service) (*generator.ServiceInfo, error) { - si := &generator.ServiceInfo{ +func (c *converter) makeService(pkg generator.PkgInfo, svc *golang.Service) (si *generator.ServiceInfo, err error) { + si = &generator.ServiceInfo{ PkgInfo: pkg, ServiceName: svc.GoName().String(), RawServiceName: svc.Name, } si.ServiceTypeName = func() string { return si.PkgRefName + "." + si.ServiceName } + if si.StreamingTransport, err = parseStreamingTransport(svc.Annotations); err != nil { + return nil, err + } + for _, f := range svc.Functions() { if strings.HasPrefix(f.Name, "_") { continue @@ -442,6 +472,23 @@ func (c *converter) makeService(pkg generator.PkgInfo, svc *golang.Service) (*ge return si, nil } +func parseStreamingTransport(annotations parser.Annotations) (string, error) { + for _, anno := range annotations { + if anno.Key != generator.AnnotationKeyStreamingTransport { + continue + } + if len(anno.Values) != 1 { + return "", fmt.Errorf("invalid streaming.transport (should be exact 1 value): %v", anno.Values) + } + if protocol, err := generator.GetStreamingTransport(anno.Values[0]); err != nil { + return "", err + } else { + return protocol, nil + } + } + return generator.GRPC, nil // if not specified +} + func (c *converter) makeMethod(si *generator.ServiceInfo, f *golang.Function) (*generator.MethodInfo, error) { st, err := streaming.ParseStreaming(f.Function) if err != nil { diff --git a/tool/internal_pkg/pluginmode/thriftgo/convertor_test.go b/tool/internal_pkg/pluginmode/thriftgo/convertor_test.go new file mode 100644 index 0000000000..807b33c54e --- /dev/null +++ b/tool/internal_pkg/pluginmode/thriftgo/convertor_test.go @@ -0,0 +1,102 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package thriftgo + +import ( + "testing" + + "github.com/cloudwego/thriftgo/parser" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/tool/internal_pkg/generator" +) + +func Test_parseStreamingTransport(t *testing.T) { + t.Run("no-annotation", func(t *testing.T) { + tp, err := parseStreamingTransport(nil) + test.Assert(t, err == nil, err) + test.Assert(t, tp == generator.GRPC, tp) + }) + t.Run("wrong-annotation:no-value", func(t *testing.T) { + anno := parser.Annotations{ + &parser.Annotation{ + Key: generator.AnnotationKeyStreamingTransport, + Values: []string{}, + }, + } + _, err := parseStreamingTransport(anno) + test.Assert(t, err != nil, err) + }) + t.Run("wrong-annotation:multiple-values", func(t *testing.T) { + anno := parser.Annotations{ + &parser.Annotation{ + Key: generator.AnnotationKeyStreamingTransport, + Values: []string{"1", "2"}, + }, + } + _, err := parseStreamingTransport(anno) + test.Assert(t, err != nil, err) + }) + t.Run("wrong-annotation:wrong-value", func(t *testing.T) { + anno := parser.Annotations{ + &parser.Annotation{ + Key: generator.AnnotationKeyStreamingTransport, + Values: []string{"xxx"}, + }, + } + _, err := parseStreamingTransport(anno) + test.Assert(t, err != nil, err) + }) + t.Run("right-annotation:ttheader", func(t *testing.T) { + anno := parser.Annotations{ + &parser.Annotation{ + Key: generator.AnnotationKeyStreamingTransport, + Values: []string{generator.TTHeader}, + }, + } + tp, err := parseStreamingTransport(anno) + test.Assert(t, err == nil, err) + test.Assert(t, tp == generator.TTHeader, tp) + }) +} + +func Test_getCombineServiceStreamingTransport(t *testing.T) { + t.Run("both-grpc", func(t *testing.T) { + services := []*generator.ServiceInfo{ + {ServiceName: "A", StreamingTransport: generator.GRPC}, + {ServiceName: "B", StreamingTransport: generator.GRPC}, + } + tp, err := getCombineServiceStreamingTransport(services) + test.Assert(t, err == nil, err) + test.Assert(t, tp == generator.GRPC, tp) + }) + t.Run("both-ttheader", func(t *testing.T) { + services := []*generator.ServiceInfo{ + {ServiceName: "A", StreamingTransport: generator.TTHeader}, + {ServiceName: "B", StreamingTransport: generator.TTHeader}, + } + tp, err := getCombineServiceStreamingTransport(services) + test.Assert(t, err == nil, err) + test.Assert(t, tp == generator.TTHeader, tp) + }) + t.Run("diff", func(t *testing.T) { + services := []*generator.ServiceInfo{ + {ServiceName: "A", StreamingTransport: generator.GRPC}, + {ServiceName: "B", StreamingTransport: generator.TTHeader}, + } + _, err := getCombineServiceStreamingTransport(services) + test.Assert(t, err != nil, err) + }) +} diff --git a/tool/internal_pkg/tpl/client.go b/tool/internal_pkg/tpl/client.go index 666e966898..b1a6d40641 100644 --- a/tool/internal_pkg/tpl/client.go +++ b/tool/internal_pkg/tpl/client.go @@ -83,7 +83,7 @@ func NewClient(destService string, opts ...client.Option) (Client, error) { {{template "@client.go-NewClient-option" .}} {{if and (eq $.Codec "protobuf") .HasStreaming}}{{/* Thrift Streaming only in StreamClient */}} - options = append(options, client.WithTransportProtocol(transport.GRPC)) + options = append(options, client.WithTransportProtocol(transport.{{.StreamingTransport}})) {{end}} options = append(options, opts...) @@ -139,7 +139,7 @@ func NewStreamClient(destService string, opts ...streamclient.Option) (StreamCli var options []client.Option options = append(options, client.WithDestService(destService)) {{- template "@client.go-NewStreamClient-option" .}} - options = append(options, client.WithTransportProtocol(transport.GRPC)) + options = append(options, client.WithTransportProtocol(transport.{{.StreamingTransport}})) options = append(options, streamclient.GetClientOptions(opts)...) kc, err := client.NewClient(serviceInfoForStreamClient(), options...)