Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support min timeout with breaker #4070

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 23 additions & 2 deletions core/stores/redis/breakerhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@ package redis

import (
"context"
"errors"
"time"

red "github.com/redis/go-redis/v9"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/timex"
)

const minTimeout = time.Millisecond * 100

var ignoreCmds = map[string]lang.PlaceholderType{
"blpop": {},
}
Expand All @@ -26,16 +31,32 @@ func (h breakerHook) ProcessHook(next red.ProcessHook) red.ProcessHook {
return next(ctx, cmd)
}

start := timex.Now()
return h.brk.DoWithAcceptable(func() error {
return next(ctx, cmd)
}, acceptable)
}, protectedAcceptable(start))
}
}

func (h breakerHook) ProcessPipelineHook(next red.ProcessPipelineHook) red.ProcessPipelineHook {
return func(ctx context.Context, cmds []red.Cmder) error {
start := timex.Now()
return h.brk.DoWithAcceptable(func() error {
return next(ctx, cmds)
}, acceptable)
}, protectedAcceptable(start))
}
}

func acceptable(err error) bool {
return err == nil || errors.Is(err, red.Nil) || errors.Is(err, context.Canceled)
}

func protectedAcceptable(start time.Duration) breaker.Acceptable {
return func(err error) bool {
if acceptable(err) {
return true
}

return errors.Is(err, context.DeadlineExceeded) && timex.Since(start) < minTimeout
}
}
4 changes: 0 additions & 4 deletions core/stores/redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -2371,10 +2371,6 @@ func withHook(hook red.Hook) Option {
}
}

func acceptable(err error) bool {
return err == nil || errors.Is(err, red.Nil) || errors.Is(err, context.Canceled)
}

func getRedis(r *Redis) (RedisNode, error) {
switch r.Type {
case ClusterType:
Expand Down
12 changes: 8 additions & 4 deletions core/stores/sqlx/sqlconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
endSpan(span, err)
}()

err = db.brk.DoWithAcceptable(func() error {
err = db.run(func() error {
var conn *sql.DB
conn, err = db.connProv()
if err != nil {
Expand Down Expand Up @@ -148,7 +148,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
endSpan(span, err)
}()

err = db.brk.DoWithAcceptable(func() error {
err = db.run(func() error {
var conn *sql.DB
conn, err = db.connProv()
if err != nil {
Expand Down Expand Up @@ -256,7 +256,7 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex
endSpan(span, err)
}()

err = db.brk.DoWithAcceptable(func() error {
err = db.run(func() error {
return transact(ctx, db, db.beginTx, fn)
}, db.acceptable)
if errors.Is(err, breaker.ErrServiceUnavailable) {
Expand Down Expand Up @@ -287,7 +287,7 @@ func (db *commonSqlConn) acceptable(err error) bool {
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
q string, args ...any) (err error) {
var scanFailed bool
err = db.brk.DoWithAcceptable(func() error {
err = db.run(func() error {
conn, err := db.connProv()
if err != nil {
db.onError(ctx, err)
Expand All @@ -311,6 +311,10 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
return
}

func (db *commonSqlConn) run(fn func() error, acceptable breaker.Acceptable) error {
return runWithBreaker(db.brk, fn, acceptable)
}

// WithAcceptable returns a SqlOption that setting the acceptable function.
// acceptable is the func to check if the error can be accepted.
func WithAcceptable(acceptable func(err error) bool) SqlOption {
Expand Down
8 changes: 6 additions & 2 deletions core/stores/sqlx/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result,
endSpan(span, err)
}()

err = s.brk.DoWithAcceptable(func() error {
err = s.run(func() error {
result, err = execStmt(ctx, s.stmt, s.query, args...)
return err
}, func(err error) bool {
Expand Down Expand Up @@ -141,7 +141,7 @@ func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any)
func (s statement) queryRows(ctx context.Context, scanFn func(any, rowsScanner) error,
v any, args ...any) error {
var scanFailed bool
err := s.brk.DoWithAcceptable(func() error {
err := s.run(func() error {
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
err := scanFn(v, rows)
if isScanFailed(err) {
Expand All @@ -159,6 +159,10 @@ func (s statement) queryRows(ctx context.Context, scanFn func(any, rowsScanner)
return err
}

func (s statement) run(fn func() error, acceptable breaker.Acceptable) error {
return runWithBreaker(s.brk, fn, acceptable)
}

// DisableLog disables logging of sql statements, includes info and slow logs.
func DisableLog() {
logSql.Set(false)
Expand Down
15 changes: 15 additions & 0 deletions core/stores/sqlx/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@ import (
"strings"
"time"

"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/mapping"
"github.com/zeromicro/go-zero/core/timex"
)

const minTimeout = time.Millisecond * 100

var errUnbalancedEscape = errors.New("no char after escape char")

func desensitize(datasource string) string {
Expand Down Expand Up @@ -148,6 +152,17 @@ func logSqlError(ctx context.Context, stmt string, err error) {
}
}

func runWithBreaker(brk breaker.Breaker, fn func() error, acceptable breaker.Acceptable) error {
start := timex.Now()
return brk.DoWithAcceptable(fn, func(err error) bool {
if acceptable(err) {
return true
}

return errors.Is(err, context.DeadlineExceeded) && timex.Since(start) < minTimeout
})
}

func writeValue(buf *strings.Builder, arg any) {
switch v := arg.(type) {
case bool:
Expand Down
14 changes: 13 additions & 1 deletion zrpc/internal/clientinterceptors/breakerinterceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,30 @@

import (
"context"
"errors"
"path"
"time"

"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/timex"
"github.com/zeromicro/go-zero/zrpc/internal/codes"
"google.golang.org/grpc"
)

const minTimeout = time.Millisecond * 100

// BreakerInterceptor is an interceptor that acts as a circuit breaker.
func BreakerInterceptor(ctx context.Context, method string, req, reply any,
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
breakerName := path.Join(cc.Target(), method)
start := timex.Now()
return breaker.DoWithAcceptable(breakerName, func() error {
return invoker(ctx, method, req, reply, cc, opts...)
}, codes.Acceptable)
}, func(err error) bool {
if codes.Acceptable(err) {
return true
}

return errors.Is(err, context.DeadlineExceeded) && timex.Since(start) < minTimeout

Check warning on line 29 in zrpc/internal/clientinterceptors/breakerinterceptor.go

View check run for this annotation

Codecov / codecov/patch

zrpc/internal/clientinterceptors/breakerinterceptor.go#L29

Added line #L29 was not covered by tests
})
}
17 changes: 17 additions & 0 deletions zrpc/internal/clientinterceptors/breakerinterceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
Expand Down Expand Up @@ -79,3 +80,19 @@ func TestBreakerInterceptor(t *testing.T) {
})
}
}

func TestBreakerTimeout(t *testing.T) {
t.Run("not timeout", func(t *testing.T) {
cc := new(grpc.ClientConn)
for i := 0; i < 1000; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
err := BreakerInterceptor(ctx, "/foo", nil, nil, cc,
func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
opts ...grpc.CallOption) error {
return context.DeadlineExceeded
})
assert.ErrorIs(t, err, context.DeadlineExceeded)
cancel()
}
})
}