Skip to content

Commit

Permalink
feat: preventing sql read amplification by logging concurrent same qu…
Browse files Browse the repository at this point in the history
…eries
  • Loading branch information
kevwan committed Apr 10, 2024
1 parent 2a7ada9 commit 8c3843a
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 3 deletions.
70 changes: 70 additions & 0 deletions core/stores/sqlx/readamplication.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package sqlx

import (
"sync"

"github.com/zeromicro/go-zero/core/logx"
)

const (
concurrencyThreshold = 3
logInterval = 60 * 1000 // 1 minute
)

var logger = logx.NewLessLogger(logInterval)

type (
concurrentReads struct {
reads map[string]*queryReference
lock sync.Mutex
}

queryReference struct {
concurrency uint32
maxConcurrency uint32
}
)

func newConcurrentReads() *concurrentReads {
return &concurrentReads{
reads: make(map[string]*queryReference),
}
}

func (r *concurrentReads) add(query string) {
r.lock.Lock()
defer r.lock.Unlock()

if ref, ok := r.reads[query]; ok {
ref.concurrency++
if ref.maxConcurrency < ref.concurrency {
ref.maxConcurrency = ref.concurrency
}
} else {
r.reads[query] = &queryReference{
concurrency: 1,
maxConcurrency: 1,
}
}
}

func (r *concurrentReads) remove(query string) {
r.lock.Lock()
defer r.lock.Unlock()
ref, ok := r.reads[query]
if !ok {
return
}

if ref.concurrency > 1 {
ref.concurrency--
return
}

// last reference to remove
delete(r.reads, query)
if ref.maxConcurrency >= concurrencyThreshold {
logger.Errorf("sql query amplified, query: %q, maxConcurrency: %d",
query, ref.maxConcurrency)
}
}
33 changes: 33 additions & 0 deletions core/stores/sqlx/readamplification_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package sqlx

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestName(t *testing.T) {
const (
query = "select foo"
times = 10
)
cr := newConcurrentReads()
assert.NotPanics(t, func() {
cr.remove(query)
})

for i := 0; i < times; i++ {
cr.add(query)
}

ref := cr.reads[query]
assert.Equal(t, uint32(times), ref.concurrency)

for i := 0; i < times; i++ {
cr.remove(query)
}

// just removed, not decremented
assert.Equal(t, uint32(1), ref.concurrency)
assert.Equal(t, uint32(times), ref.maxConcurrency)
}
9 changes: 6 additions & 3 deletions core/stores/sqlx/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ import (
const defaultSlowThreshold = time.Millisecond * 500

var (
slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
logSql = syncx.ForAtomicBool(true)
logSlowSql = syncx.ForAtomicBool(true)
slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
logSql = syncx.ForAtomicBool(true)
logSlowSql = syncx.ForAtomicBool(true)
concurrentQueries = newConcurrentReads()
)

type (
Expand Down Expand Up @@ -266,6 +267,7 @@ func (n nilGuard) finish(_ context.Context, _ error) {
}

func (e *realSqlGuard) finish(ctx context.Context, err error) {
concurrentQueries.remove(e.stmt)
duration := timex.Since(e.startTime)
if duration > slowThreshold.Load() {
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] %s: slowcall - %s", e.command, e.stmt)
Expand All @@ -289,6 +291,7 @@ func (e *realSqlGuard) start(q string, args ...any) error {

e.stmt = stmt
e.startTime = timex.Now()
concurrentQueries.add(stmt)

return nil
}

0 comments on commit 8c3843a

Please sign in to comment.