Skip to content

Commit

Permalink
fix: pass method and isClient to writer and reader
Browse files Browse the repository at this point in the history
  • Loading branch information
Marina-Sakai committed Jun 4, 2024
1 parent 560e4bc commit 14c6206
Show file tree
Hide file tree
Showing 32 changed files with 229 additions and 437 deletions.
21 changes: 8 additions & 13 deletions client/genericclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func NewClientWithServiceInfo(destService string, g generic.Generic, svcInfo *se
return nil, err
}
cli := &genericServiceClient{
svcInfo: svcInfo,
kClient: kc,
g: g,
}
Expand Down Expand Up @@ -86,33 +87,27 @@ type Client interface {
}

type genericServiceClient struct {
svcInfo *serviceinfo.ServiceInfo
kClient client.Client
g generic.Generic
}

func (gc *genericServiceClient) GenericCall(ctx context.Context, method string, request interface{}, callOptions ...callopt.Option) (response interface{}, err error) {
ctx = client.NewCtxWithCallOptions(ctx, callOptions)
var _args generic.Args
_args := gc.svcInfo.MethodInfo(serviceinfo.GenericService).NewArgs().(*generic.Args)
_args.Method = method
_args.Request = request

mt, err := gc.g.GetMethod(request, method)
if err != nil {
return nil, err
}
codec := gc.g.CodecInfo()
if codec != nil {
codec.SetMethod(mt.Name)
codec.SetIsClient(true)
_args.SetCodec(codec.GetMessageReaderWriter())
}
if mt.Oneway {
return nil, gc.kClient.Call(ctx, mt.Name, &_args, nil)
return nil, gc.kClient.Call(ctx, mt.Name, _args, nil)
}
var _result generic.Result
if codec != nil {
_result.SetCodec(codec.GetMessageReaderWriter())
}
if err = gc.kClient.Call(ctx, mt.Name, &_args, &_result); err != nil {

_result := gc.svcInfo.MethodInfo(serviceinfo.GenericService).NewResult().(*generic.Result)
if err = gc.kClient.Call(ctx, mt.Name, _args, _result); err != nil {
return
}
return _result.GetSuccess(), nil
Expand Down
5 changes: 2 additions & 3 deletions internal/mocks/generic/thrift.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 19 additions & 17 deletions pkg/generic/generic_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (

gproto "github.com/cloudwego/kitex/pkg/generic/proto"
gthrift "github.com/cloudwego/kitex/pkg/generic/thrift"
"github.com/cloudwego/kitex/pkg/remote"
codecThrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift"
"github.com/cloudwego/kitex/pkg/serviceinfo"
)
Expand All @@ -47,6 +46,7 @@ func newServiceInfo(pcType serviceinfo.PayloadCodec, codec serviceinfo.CodecInfo
var svcName string

if codec == nil {
// TODO: support multi-service for binary generic
svcName = serviceinfo.GenericService
methods = map[string]serviceinfo.MethodInfo{
serviceinfo.GenericMethod: serviceinfo.NewMethodInfo(callHandler, newGenericServiceCallArgs, newGenericServiceCallResult, false),
Expand Down Expand Up @@ -113,6 +113,7 @@ var (
_ WithCodec = (*Args)(nil)
)

// Deprecated: it's not used by kitex anymore.
// SetCodec ...
func (g *Args) SetCodec(inner interface{}) {
g.inner = inner
Expand All @@ -126,48 +127,48 @@ func (g *Args) GetOrSetBase() interface{} {
}

// Write ...
func (g *Args) Write(ctx context.Context, out thrift.TProtocol) error {
func (g *Args) Write(ctx context.Context, method string, isClient bool, out thrift.TProtocol) error {
if err, ok := g.inner.(error); ok {
return err
}
if w, ok := g.inner.(gthrift.MessageWriter); ok {
return w.Write(ctx, out, g.Request, g.base)
return w.Write(ctx, out, g.Request, method, isClient, g.base)
}
return fmt.Errorf("unexpected Args writer type: %T", g.inner)
}

func (g *Args) WritePb(ctx context.Context) (interface{}, error) {
func (g *Args) WritePb(ctx context.Context, method string, isClient bool) (interface{}, error) {
if err, ok := g.inner.(error); ok {
return nil, err
}
if w, ok := g.inner.(gproto.MessageWriter); ok {
return w.Write(ctx, g.Request)
return w.Write(ctx, g.Request, method, isClient)
}
return nil, fmt.Errorf("unexpected Args writer type: %T", g.inner)
}

// Read ...
func (g *Args) Read(ctx context.Context, method string, msgType remote.MessageType, dataLen int, in thrift.TProtocol) error {
func (g *Args) Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) error {
if err, ok := g.inner.(error); ok {
return err
}
if rw, ok := g.inner.(gthrift.MessageReader); ok {
g.Method = method
var err error
g.Request, err = rw.Read(ctx, method, msgType, dataLen, in)
g.Request, err = rw.Read(ctx, method, isClient, dataLen, in)
return err
}
return fmt.Errorf("unexpected Args reader type: %T", g.inner)
}

func (g *Args) ReadPb(ctx context.Context, method string, in []byte) error {
func (g *Args) ReadPb(ctx context.Context, method string, isClient bool, in []byte) error {
if err, ok := g.inner.(error); ok {
return err
}
if w, ok := g.inner.(gproto.MessageReader); ok {
g.Method = method
var err error
g.Request, err = w.Read(ctx, method, in)
g.Request, err = w.Read(ctx, method, isClient, in)
return err
}
return fmt.Errorf("unexpected Args reader type: %T", g.inner)
Expand All @@ -190,52 +191,53 @@ var (
_ WithCodec = (*Result)(nil)
)

// Deprecated: it's not used by kitex anymore.
// SetCodec ...
func (r *Result) SetCodec(inner interface{}) {
r.inner = inner
}

// Write ...
func (r *Result) Write(ctx context.Context, out thrift.TProtocol) error {
func (r *Result) Write(ctx context.Context, method string, isClient bool, out thrift.TProtocol) error {
if err, ok := r.inner.(error); ok {
return err
}
if w, ok := r.inner.(gthrift.MessageWriter); ok {
return w.Write(ctx, out, r.Success, nil)
return w.Write(ctx, out, r.Success, method, isClient, nil)
}
return fmt.Errorf("unexpected Result writer type: %T", r.inner)
}

func (r *Result) WritePb(ctx context.Context) (interface{}, error) {
func (r *Result) WritePb(ctx context.Context, method string, isClient bool) (interface{}, error) {
if err, ok := r.inner.(error); ok {
return nil, err
}
if w, ok := r.inner.(gproto.MessageWriter); ok {
return w.Write(ctx, r.Success)
return w.Write(ctx, r.Success, method, isClient)
}
return nil, fmt.Errorf("unexpected Result writer type: %T", r.inner)
}

// Read ...
func (r *Result) Read(ctx context.Context, method string, msgType remote.MessageType, dataLen int, in thrift.TProtocol) error {
func (r *Result) Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) error {
if err, ok := r.inner.(error); ok {
return err
}
if w, ok := r.inner.(gthrift.MessageReader); ok {
var err error
r.Success, err = w.Read(ctx, method, msgType, dataLen, in)
r.Success, err = w.Read(ctx, method, isClient, dataLen, in)
return err
}
return fmt.Errorf("unexpected Result reader type: %T", r.inner)
}

func (r *Result) ReadPb(ctx context.Context, method string, in []byte) error {
func (r *Result) ReadPb(ctx context.Context, method string, isClient bool, in []byte) error {
if err, ok := r.inner.(error); ok {
return err
}
if w, ok := r.inner.(gproto.MessageReader); ok {
var err error
r.Success, err = w.Read(ctx, method, in)
r.Success, err = w.Read(ctx, method, isClient, in)
return err
}
return fmt.Errorf("unexpected Result reader type: %T", r.inner)
Expand Down
18 changes: 9 additions & 9 deletions pkg/generic/generic_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,21 @@ func TestGenericService(t *testing.T) {
test.Assert(t, base != nil)
a.SetCodec(struct{}{})
// write not ok
err := a.Write(ctx, tProto)
err := a.Write(ctx, method, true, tProto)
test.Assert(t, err.Error() == "unexpected Args writer type: struct {}")

// Write expect
argWriteInner.EXPECT().Write(ctx, tProto, a.Request, a.GetOrSetBase()).Return(nil)
a.SetCodec(argWriteInner)
// write ok
err = a.Write(ctx, tProto)
test.Assert(t, err == nil)
err = a.Write(ctx, method, true, tProto)
test.Assert(t, err == nil, err)
// read not ok
err = a.Read(ctx, method, -1, 0, tProto)
err = a.Read(ctx, method, false, 0, tProto)
test.Assert(t, strings.Contains(err.Error(), "unexpected Args reader type"))
// read ok
a.SetCodec(rInner)
err = a.Read(ctx, method, -1, 0, tProto)
err = a.Read(ctx, method, false, 0, tProto)
test.Assert(t, err == nil)

// Result...
Expand All @@ -79,20 +79,20 @@ func TestGenericService(t *testing.T) {
test.Assert(t, ok == true)

// write not ok
err = r.Write(ctx, tProto)
err = r.Write(ctx, method, false, tProto)
test.Assert(t, err.Error() == "unexpected Result writer type: <nil>")
// Write expect
resultWriteInner.EXPECT().Write(ctx, tProto, r.Success, (*gthrift.Base)(nil)).Return(nil).AnyTimes()
r.SetCodec(resultWriteInner)
// write ok
err = r.Write(ctx, tProto)
err = r.Write(ctx, method, false, tProto)
test.Assert(t, err == nil)
// read not ok
err = r.Read(ctx, method, -1, 0, tProto)
err = r.Read(ctx, method, true, 0, tProto)
test.Assert(t, strings.Contains(err.Error(), "unexpected Result reader type"))
// read ok
r.SetCodec(rInner)
err = r.Read(ctx, method, -1, 0, tProto)
err = r.Read(ctx, method, true, 0, tProto)
test.Assert(t, err == nil)

r.SetSuccess(nil)
Expand Down
9 changes: 0 additions & 9 deletions pkg/generic/httppbthrift_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ type httpPbThriftCodec struct {
provider DescriptorProvider
pbProvider PbDescriptorProvider
svcName string
method string
}

func newHTTPPbThriftCodec(p DescriptorProvider, pbp PbDescriptorProvider) *httpPbThriftCodec {
Expand Down Expand Up @@ -102,14 +101,6 @@ func (c *httpPbThriftCodec) GetMessageReaderWriter() interface{} {
return thrift.NewHTTPPbReaderWriter(svcDsc, pbSvcDsc)
}

func (c *httpPbThriftCodec) SetMethod(method string) {
c.method = method
}

// SetIsClient for httpPb generic does nothing because httpPb generic is only for client
func (c *httpPbThriftCodec) SetIsClient(isClient bool) {
}

func (c *httpPbThriftCodec) GetIDLServiceName() string {
return c.svcName
}
Expand Down
22 changes: 4 additions & 18 deletions pkg/generic/httpthrift_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ type httpThriftCodec struct {
dynamicgoEnabled bool
useRawBodyForHTTPResp bool
svcName string
method string
}

func newHTTPThriftCodec(p DescriptorProvider, opts *Options) *httpThriftCodec {
Expand Down Expand Up @@ -84,24 +83,19 @@ func (c *httpThriftCodec) update() {
func (c *httpThriftCodec) GetMessageReaderWriter() interface{} {
svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor)
if !ok {
return errors.New("get method name failed, no ServiceDescriptor")
return errors.New("get parser ServiceDescriptor failed")
}
rw := thrift.NewHTTPReaderWriter(svcDsc)
if err := c.configureHTTPRequestWriter(rw.WriteHTTPRequest); err != nil {
return err
}
c.configureHTTPRequestWriter(rw.WriteHTTPRequest)
c.configureHTTPResponseReader(rw.ReadHTTPResponse)
return rw
}

func (c *httpThriftCodec) configureHTTPRequestWriter(writer *thrift.WriteHTTPRequest) error {
func (c *httpThriftCodec) configureHTTPRequestWriter(writer *thrift.WriteHTTPRequest) {
writer.SetBinaryWithBase64(c.binaryWithBase64)
if c.dynamicgoEnabled {
if err := writer.SetDynamicGo(&c.convOpts, &c.convOptsWithThriftBase, c.method); err != nil {
return err
}
writer.SetDynamicGo(&c.convOpts, &c.convOptsWithThriftBase)
}
return nil
}

func (c *httpThriftCodec) configureHTTPResponseReader(reader *thrift.ReadHTTPResponse) {
Expand All @@ -112,14 +106,6 @@ func (c *httpThriftCodec) configureHTTPResponseReader(reader *thrift.ReadHTTPRes
}
}

func (c *httpThriftCodec) SetMethod(method string) {
c.method = method
}

// SetIsClient for http generic does nothing because http generic is only for client
func (c *httpThriftCodec) SetIsClient(isClient bool) {
}

func (c *httpThriftCodec) GetIDLServiceName() string {
return c.svcName
}
Expand Down
14 changes: 1 addition & 13 deletions pkg/generic/httpthrift_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package generic
import (
"bytes"
"net/http"
"strings"
"testing"

"github.com/bytedance/sonic"
Expand Down Expand Up @@ -69,9 +68,6 @@ func TestHttpThriftCodec(t *testing.T) {
_, ok := rw.(error)
test.Assert(t, !ok)

htc.SetMethod(method.Name)
test.Assert(t, htc.method == "BinaryEcho")

rw = htc.GetMessageReaderWriter()
_, ok = rw.(gthrift.MessageWriter)
test.Assert(t, ok)
Expand Down Expand Up @@ -104,15 +100,7 @@ func TestHttpThriftCodecWithDynamicGo(t *testing.T) {
test.Assert(t, htc.GetIDLServiceName() == "ExampleService")

rw := htc.GetMessageReaderWriter()
err, ok := rw.(error)
test.Assert(t, ok)
test.Assert(t, strings.Contains(err.Error(), "missing method"))

htc.SetMethod(method.Name)
test.Assert(t, htc.method == "BinaryEcho")

rw = htc.GetMessageReaderWriter()
_, ok = rw.(gthrift.MessageWriter)
_, ok := rw.(gthrift.MessageWriter)
test.Assert(t, ok)
_, ok = rw.(gthrift.MessageReader)
test.Assert(t, ok)
Expand Down
Loading

0 comments on commit 14c6206

Please sign in to comment.