Skip to content

Commit

Permalink
refactor: no need isClient
Browse files Browse the repository at this point in the history
  • Loading branch information
Marina-Sakai committed Jun 17, 2024
1 parent 7d115c6 commit 78bc78d
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 34 deletions.
20 changes: 10 additions & 10 deletions pkg/generic/generic_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ type Args struct {

var (
_ codecThrift.MessageReaderWithMethodWithContext = (*Args)(nil)
_ codecThrift.MessageWriterWithContext = (*Args)(nil)
_ codecThrift.MessageWriterWithMethodWithContext = (*Args)(nil)
_ WithCodec = (*Args)(nil)
)

Expand All @@ -132,12 +132,12 @@ func (g *Args) GetOrSetBase() interface{} {
}

// Write ...
func (g *Args) Write(ctx context.Context, method string, isClient bool, out thrift.TProtocol) error {
func (g *Args) Write(ctx context.Context, method string, out thrift.TProtocol) error {
if err, ok := g.inner.(error); ok {
return err
}
if w, ok := g.inner.(gthrift.MessageWriter); ok {
return w.Write(ctx, out, g.Request, method, isClient, g.base)
return w.Write(ctx, out, g.Request, method, true, g.base)
}
return fmt.Errorf("unexpected Args writer type: %T", g.inner)
}
Expand All @@ -153,14 +153,14 @@ func (g *Args) WritePb(ctx context.Context, method string, isClient bool) (inter
}

// Read ...
func (g *Args) Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) error {
func (g *Args) Read(ctx context.Context, method string, dataLen int, in thrift.TProtocol) error {
if err, ok := g.inner.(error); ok {
return err
}
if rw, ok := g.inner.(gthrift.MessageReader); ok {
g.Method = method
var err error
g.Request, err = rw.Read(ctx, method, isClient, dataLen, in)
g.Request, err = rw.Read(ctx, method, false, dataLen, in)
return err
}
return fmt.Errorf("unexpected Args reader type: %T", g.inner)
Expand Down Expand Up @@ -192,7 +192,7 @@ type Result struct {

var (
_ codecThrift.MessageReaderWithMethodWithContext = (*Result)(nil)
_ codecThrift.MessageWriterWithContext = (*Result)(nil)
_ codecThrift.MessageWriterWithMethodWithContext = (*Result)(nil)
_ WithCodec = (*Result)(nil)
)

Expand All @@ -203,12 +203,12 @@ func (r *Result) SetCodec(inner interface{}) {
}

// Write ...
func (r *Result) Write(ctx context.Context, method string, isClient bool, out thrift.TProtocol) error {
func (r *Result) Write(ctx context.Context, method string, out thrift.TProtocol) error {
if err, ok := r.inner.(error); ok {
return err
}
if w, ok := r.inner.(gthrift.MessageWriter); ok {
return w.Write(ctx, out, r.Success, method, isClient, nil)
return w.Write(ctx, out, r.Success, method, false, nil)
}
return fmt.Errorf("unexpected Result writer type: %T", r.inner)
}
Expand All @@ -224,13 +224,13 @@ func (r *Result) WritePb(ctx context.Context, method string, isClient bool) (int
}

// Read ...
func (r *Result) Read(ctx context.Context, method string, isClient bool, dataLen int, in thrift.TProtocol) error {
func (r *Result) Read(ctx context.Context, method string, dataLen int, in thrift.TProtocol) error {
if err, ok := r.inner.(error); ok {
return err
}
if w, ok := r.inner.(gthrift.MessageReader); ok {
var err error
r.Success, err = w.Read(ctx, method, isClient, dataLen, in)
r.Success, err = w.Read(ctx, method, true, dataLen, in)
return err
}
return fmt.Errorf("unexpected Result reader type: %T", r.inner)
Expand Down
16 changes: 8 additions & 8 deletions pkg/generic/generic_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,21 @@ func TestGenericService(t *testing.T) {
test.Assert(t, base != nil)
a.SetCodec(struct{}{})
// write not ok
err := a.Write(ctx, method, true, tProto)
err := a.Write(ctx, method, tProto)
test.Assert(t, err.Error() == "unexpected Args writer type: struct {}")

// Write expect
argWriteInner.EXPECT().Write(ctx, tProto, a.Request, a.GetOrSetBase()).Return(nil)
a.SetCodec(argWriteInner)
// write ok
err = a.Write(ctx, method, true, tProto)
err = a.Write(ctx, method, tProto)
test.Assert(t, err == nil, err)
// read not ok
err = a.Read(ctx, method, false, 0, tProto)
err = a.Read(ctx, method, 0, tProto)
test.Assert(t, strings.Contains(err.Error(), "unexpected Args reader type"))
// read ok
a.SetCodec(rInner)
err = a.Read(ctx, method, false, 0, tProto)
err = a.Read(ctx, method, 0, tProto)
test.Assert(t, err == nil)

// Result...
Expand All @@ -79,20 +79,20 @@ func TestGenericService(t *testing.T) {
test.Assert(t, ok == true)

// write not ok
err = r.Write(ctx, method, false, tProto)
err = r.Write(ctx, method, tProto)
test.Assert(t, err.Error() == "unexpected Result writer type: <nil>")
// Write expect
resultWriteInner.EXPECT().Write(ctx, tProto, r.Success, (*gthrift.Base)(nil)).Return(nil).AnyTimes()
r.SetCodec(resultWriteInner)
// write ok
err = r.Write(ctx, method, false, tProto)
err = r.Write(ctx, method, tProto)
test.Assert(t, err == nil)
// read not ok
err = r.Read(ctx, method, true, 0, tProto)
err = r.Read(ctx, method, 0, tProto)
test.Assert(t, strings.Contains(err.Error(), "unexpected Result reader type"))
// read ok
r.SetCodec(rInner)
err = r.Read(ctx, method, true, 0, tProto)
err = r.Read(ctx, method, 0, tProto)
test.Assert(t, err == nil)

r.SetSuccess(nil)
Expand Down
8 changes: 4 additions & 4 deletions pkg/remote/codec/thrift/thrift.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,9 @@ func (c thriftCodec) Name() string {
return serviceinfo.Thrift.String()
}

// MessageWriterWithContext write to thrift.TProtocol
type MessageWriterWithContext interface {
Write(ctx context.Context, method string, isClient bool, oprot thrift.TProtocol) error
// MessageWriterWithMethodWithContext write to thrift.TProtocol
type MessageWriterWithMethodWithContext interface {
Write(ctx context.Context, method string, oprot thrift.TProtocol) error
}

// MessageWriter write to thrift.TProtocol
Expand All @@ -257,7 +257,7 @@ type MessageReader interface {

// MessageReaderWithMethodWithContext read from thrift.TProtocol with method
type MessageReaderWithMethodWithContext interface {
Read(ctx context.Context, method string, isClient bool, dataLen int, oprot thrift.TProtocol) error
Read(ctx context.Context, method string, dataLen int, oprot thrift.TProtocol) error
}

type ThriftMsgFastCodec interface {
Expand Down
8 changes: 4 additions & 4 deletions pkg/remote/codec/thrift/thrift_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([
func verifyMarshalBasicThriftDataType(data interface{}) error {
switch data.(type) {
case MessageWriter:
case MessageWriterWithContext:
case MessageWriterWithMethodWithContext:
default:
return errEncodeMismatchMsgType
}
Expand All @@ -96,8 +96,8 @@ func marshalBasicThriftData(ctx context.Context, tProt thrift.TProtocol, data in
if err := msg.Write(tProt); err != nil {
return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error()))
}
case MessageWriterWithContext:
if err := msg.Write(ctx, method, rpcRole == remote.Client, tProt); err != nil {
case MessageWriterWithMethodWithContext:
if err := msg.Write(ctx, method, tProt); err != nil {
return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error()))
}
default:
Expand Down Expand Up @@ -242,7 +242,7 @@ func decodeBasicThriftData(ctx context.Context, tProt thrift.TProtocol, method s
}
case MessageReaderWithMethodWithContext:
// methodName is necessary for generic calls to methodInfo from serviceInfo
if err = t.Read(ctx, method, rpcRole == remote.Client, dataLen, tProt); err != nil {
if err = t.Read(ctx, method, dataLen, tProt); err != nil {
return remote.NewTransError(remote.ProtocolError, err)
}
default:
Expand Down
16 changes: 8 additions & 8 deletions pkg/remote/codec/thrift/thrift_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,20 @@ func init() {
}

type mockWithContext struct {
ReadFunc func(ctx context.Context, method string, isClient bool, dataLen int, oprot thrift.TProtocol) error
WriteFunc func(ctx context.Context, method string, isClient bool, oprot thrift.TProtocol) error
ReadFunc func(ctx context.Context, method string, dataLen int, oprot thrift.TProtocol) error
WriteFunc func(ctx context.Context, method string, oprot thrift.TProtocol) error
}

func (m *mockWithContext) Read(ctx context.Context, method string, isClient bool, dataLen int, oprot thrift.TProtocol) error {
func (m *mockWithContext) Read(ctx context.Context, method string, dataLen int, oprot thrift.TProtocol) error {
if m.ReadFunc != nil {
return m.ReadFunc(ctx, method, isClient, dataLen, oprot)
return m.ReadFunc(ctx, method, dataLen, oprot)
}
return nil
}

func (m *mockWithContext) Write(ctx context.Context, method string, isClient bool, oprot thrift.TProtocol) error {
func (m *mockWithContext) Write(ctx context.Context, method string, oprot thrift.TProtocol) error {
if m.WriteFunc != nil {
return m.WriteFunc(ctx, method, isClient, oprot)
return m.WriteFunc(ctx, method, oprot)
}
return nil
}
Expand All @@ -85,7 +85,7 @@ func TestWithContext(t *testing.T) {
t.Run(tb.Name, func(t *testing.T) {
ctx := context.Background()

req := &mockWithContext{WriteFunc: func(ctx context.Context, method string, isClient bool, oprot thrift.TProtocol) error {
req := &mockWithContext{WriteFunc: func(ctx context.Context, method string, oprot thrift.TProtocol) error {
return nil
}}
ink := rpcinfo.NewInvocation("", "mock")
Expand All @@ -98,7 +98,7 @@ func TestWithContext(t *testing.T) {
buf.Flush()

{
resp := &mockWithContext{ReadFunc: func(ctx context.Context, method string, isClient bool, dataLen int, oprot thrift.TProtocol) error {
resp := &mockWithContext{ReadFunc: func(ctx context.Context, method string, dataLen int, oprot thrift.TProtocol) error {
return nil
}}
ink := rpcinfo.NewInvocation("", "mock")
Expand Down

0 comments on commit 78bc78d

Please sign in to comment.