196 changes: 174 additions & 22 deletions go/pools/smartconnpool/pool.go
@@ -29,6 +29,13 @@ import (
"vitess.io/vitess/go/vt/vterrors"
)

// The states a pool can be in.
const (
UNINITIALIZED = iota
OPENED
CLOSED
)
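
For orientation while reading the rest of the diff: these constants encode a one-way UNINITIALIZED → OPENED → CLOSED lifecycle, enforced further down with compare-and-swap, so a pool that has been closed can never pass through Open again. A minimal standalone sketch of just that state machine (the lifecycle type and its method names are illustrative, not part of the PR):

package main

import (
	"fmt"
	"sync/atomic"
)

const (
	UNINITIALIZED = iota
	OPENED
	CLOSED
)

type lifecycle struct{ state atomic.Uint32 }

// open succeeds only for a brand-new pool; a closed pool stays closed.
func (l *lifecycle) open() bool { return l.state.CompareAndSwap(UNINITIALIZED, OPENED) }

// close succeeds only once, and only after a successful open.
func (l *lifecycle) close() bool { return l.state.CompareAndSwap(OPENED, CLOSED) }

func main() {
	var l lifecycle
	fmt.Println(l.open())  // true: UNINITIALIZED -> OPENED
	fmt.Println(l.open())  // false: already opened
	fmt.Println(l.close()) // true: OPENED -> CLOSED
	fmt.Println(l.open())  // false: CLOSED is terminal
}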

var (
// ErrTimeout is returned if a connection get times out.
ErrTimeout = vterrors.New(vtrpcpb.Code_RESOURCE_EXHAUSTED, "connection pool timed out")
@@ -124,8 +131,12 @@ type ConnPool[C Connection] struct {
capacity atomic.Int64

// workers is a waitgroup for all the currently running worker goroutines
workers sync.WaitGroup
close chan struct{}
workers sync.WaitGroup
close chan struct{}

// state represents the state the pool is in: uninitialized, open, or closed.
state atomic.Uint32

capacityMu sync.Mutex

config struct {
@@ -193,6 +204,19 @@ func (pool *ConnPool[C]) open() {
// The expire worker takes care of removing from the waiter list any clients whose
// context has been cancelled.
pool.runWorker(pool.close, 100*time.Millisecond, func(_ time.Time) bool {
if pool.IsClosed() {
// Clean up any waiters that may have been added after the pool was closed
pool.wait.expire(true)

// If there are no more active connections, we can close the channel and stop
// the workers
if pool.active.Load() == 0 {
close(pool.close)
}

return true
}

maybeStarving := pool.wait.expire(false)

// Do not allow connections to starve; if there are waiters in the queue
@@ -234,8 +258,8 @@
// Open starts the background workers that manage the pool and gets it ready
// to start serving out connections.
func (pool *ConnPool[C]) Open(connect Connector[C], refresh RefreshCheck) *ConnPool[C] {
if pool.close != nil {
// already open
if !pool.state.CompareAndSwap(UNINITIALIZED, OPENED) {
// already open or closed
return pool
}

@@ -263,20 +287,41 @@ func (pool *ConnPool[C]) CloseWithContext(ctx context.Context) error {
pool.capacityMu.Lock()
defer pool.capacityMu.Unlock()

if pool.close == nil || pool.capacity.Load() == 0 {
// already closed
if !pool.state.CompareAndSwap(OPENED, CLOSED) {
// Already closed or uninitialized
return nil
}

// close all the connections in the pool; if we time out while waiting for
// users to return our connections, we still want to finish the shutdown
// for the pool
err := pool.setCapacity(ctx, 0)
// close all the connections in the pool

close(pool.close)
pool.workers.Wait()
pool.close = nil
return err
newcap := int64(0)
oldcap := pool.capacity.Swap(newcap)
if oldcap == newcap {
return nil
}

// close connections until we're under capacity
for {
// make sure there are no clients waiting for connections because they won't be returned in the future
pool.wait.expire(true)

// try closing from connections which are currently idle in the stacks
conn := pool.getFromSettingsStack(nil)
if conn == nil {
conn = pool.pop(&pool.clean)
}
if conn == nil {
break
}
conn.Close()
pool.closedConn()
}

if pool.active.Load() == 0 {
close(pool.close)
}
Copilot AI commented on Aug 13, 2025:

There's a potential race condition here. The active count could change between the Load() call and the close() call. Consider using a mutex or atomic operation to ensure thread safety when checking active connections and closing the channel.

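One race-free shape that addresses what the comment asks for — purely a sketch, not the PR's actual fix — lets the atomic decrement itself detect the drop to zero (Add returns the new count, so exactly one goroutine observes the transition) and makes the channel close idempotent with sync.Once:

package main

import (
	"sync"
	"sync/atomic"
)

// closer is a hypothetical stand-in for the pool's shutdown bookkeeping.
type closer struct {
	active atomic.Int64
	closed atomic.Bool
	done   chan struct{}
	once   sync.Once
}

// release plays the role of closedConn: it retires one connection and, if the
// pool is closed and this was the last connection, shuts the channel.
func (c *closer) release() {
	if c.active.Add(-1) == 0 && c.closed.Load() {
		c.once.Do(func() { close(c.done) })
	}
}

// Close marks the pool closed and shuts the channel immediately when nothing
// is borrowed; otherwise the final release does it. Once makes the two paths
// safe against each other, so no interleaving can close the channel twice.
func (c *closer) Close() {
	c.closed.Store(true)
	if c.active.Load() == 0 {
		c.once.Do(func() { close(c.done) })
	}
}

func main() {
	c := &closer{done: make(chan struct{})}
	c.active.Store(1) // one connection still borrowed
	c.Close()         // pool is closed, channel stays open for now
	go c.release()    // the last connection comes back concurrently
	<-c.done          // unblocks exactly once; no double-close panic
}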

return nil
}

func (pool *ConnPool[C]) reopen() {
@@ -305,7 +350,12 @@ func (pool *ConnPool[C]) reopen() {

// IsOpen returns whether the pool is open
func (pool *ConnPool[C]) IsOpen() bool {
return pool.close != nil
return pool.state.Load() == OPENED
}

// IsClosed returns whether the pool is closed
func (pool *ConnPool[C]) IsClosed() bool {
return pool.state.Load() == CLOSED
}

// Capacity returns the maximum amount of connections that this pool can maintain open
@@ -363,7 +413,7 @@ func (pool *ConnPool[C]) Get(ctx context.Context, setting *Setting) (*Pooled[C],
if ctx.Err() != nil {
return nil, ErrCtxTimeout
}
if pool.capacity.Load() == 0 {
if pool.state.Load() != OPENED {
return nil, ErrConnPoolClosed
}
if setting == nil {
@@ -377,6 +427,16 @@ func (pool *ConnPool[C]) Get(ctx context.Context, setting *Setting) (*Pooled[C],
func (pool *ConnPool[C]) put(conn *Pooled[C]) {
pool.borrowed.Add(-1)

// Close connection if pool is closed
if pool.IsClosed() {
if conn != nil {
conn.Close()
pool.closedConn()
}

return
}

if conn == nil {
var err error
// Using context.Background() is fine since MySQL connection already enforces
@@ -412,10 +472,24 @@ func (pool *ConnPool[C]) tryReturnConn(conn *Pooled[C]) bool {
connSetting := conn.Conn.Setting()
if connSetting == nil {
pool.clean.Push(conn)

// Close connection if pool is closed
if pool.IsClosed() {
conn = pool.pop(&pool.clean)
conn.Close()
pool.closedConn()
}
} else {
stack := connSetting.bucket & stackMask
pool.settings[stack].Push(conn)
pool.freshSettingsStack.Store(int64(stack))

// Close connection if pool is closed
if pool.IsClosed() {
conn = pool.pop(&pool.settings[stack])
conn.Close()
pool.closedConn()
Copilot AI commented on Aug 13, 2025:

This code pushes a connection to the clean stack and immediately pops it back. This is inefficient and could cause issues if other goroutines are accessing the stack. Consider closing the connection directly without adding it to the stack first.

Suggested change
pool.closedConn()
// Close connection if pool is closed
if pool.IsClosed() {
conn.Close()
pool.closedConn()
} else {
pool.clean.Push(conn)
}
} else {
stack := connSetting.bucket & stackMask
// Close connection if pool is closed
if pool.IsClosed() {
conn.Close()
pool.closedConn()
} else {
pool.settings[stack].Push(conn)
pool.freshSettingsStack.Store(int64(stack))

}
Copilot AI commented on Aug 13, 2025:

Similar to the previous issue, this code pushes a connection to the settings stack and immediately pops it back. This creates unnecessary stack operations and potential race conditions. The connection should be closed directly without stack manipulation.

Suggested change
}
if pool.IsClosed() {
conn.Close()
pool.closedConn()
return false
}
if connSetting == nil {
pool.clean.Push(conn)
} else {
stack := connSetting.bucket & stackMask
pool.settings[stack].Push(conn)
pool.freshSettingsStack.Store(int64(stack))

}
return false
}
@@ -759,55 +833,133 @@ func (pool *ConnPool[C]) StatsJSON() map[string]any {
}
}

// RegisterStats registers this pool's metrics into a stats Exporter
func (pool *ConnPool[C]) RegisterStats(stats *servenv.Exporter, name string) {
if stats == nil || name == "" {
return
}
type StatsExporter[C Connection] struct {
// The Pool for which this exporter is exporting stats.
// It is an atomic pointer so that it can be updated safely.
// The pointer is nil if the pool has not been registered yet.
pool atomic.Pointer[ConnPool[C]]
}

pool.Name = name
func NewStatsExporter[C Connection](stats *servenv.Exporter, name string) *StatsExporter[C] {
se := &StatsExporter[C]{}

stats.NewGaugeFunc(name+"Capacity", "Tablet server conn pool capacity", func() int64 {
pool := se.pool.Load()
if pool == nil {
return 0
}

return pool.Capacity()
})
stats.NewGaugeFunc(name+"Available", "Tablet server conn pool available", func() int64 {
pool := se.pool.Load()
if pool == nil {
return 0
}

return pool.Available()
})
stats.NewGaugeFunc(name+"Active", "Tablet server conn pool active", func() int64 {
pool := se.pool.Load()
if pool == nil {
return 0
}

return pool.Active()
})
stats.NewGaugeFunc(name+"InUse", "Tablet server conn pool in use", func() int64 {
pool := se.pool.Load()
if pool == nil {
return 0
}

return pool.InUse()
})
stats.NewGaugeFunc(name+"MaxCap", "Tablet server conn pool max cap", func() int64 {
pool := se.pool.Load()
if pool == nil {
return 0
}

// the smartconnpool doesn't have a maximum capacity
return pool.Capacity()
})
stats.NewCounterFunc(name+"WaitCount", "Tablet server conn pool wait count", func() int64 {
pool := se.pool.Load()
if pool == nil {
return 0
}

return pool.Metrics.WaitCount()
})
stats.NewCounterDurationFunc(name+"WaitTime", "Tablet server wait time", func() time.Duration {
pool := se.pool.Load()
if pool == nil {
return 0
}

return pool.Metrics.WaitTime()
})
stats.NewGaugeDurationFunc(name+"IdleTimeout", "Tablet server idle timeout", func() time.Duration {
pool := se.pool.Load()
if pool == nil {
return 0
}

return pool.IdleTimeout()
})
stats.NewCounterFunc(name+"IdleClosed", "Tablet server conn pool idle closed", func() int64 {
pool := se.pool.Load()
if pool == nil {
return 0
}

return pool.Metrics.IdleClosed()
})
stats.NewCounterFunc(name+"MaxLifetimeClosed", "Tablet server conn pool refresh closed", func() int64 {
pool := se.pool.Load()
if pool == nil {
return 0
}

return pool.Metrics.MaxLifetimeClosed()
})
stats.NewCounterFunc(name+"Get", "Tablet server conn pool get count", func() int64 {
pool := se.pool.Load()
if pool == nil {
return 0
}

return pool.Metrics.GetCount()
})
stats.NewCounterFunc(name+"GetSetting", "Tablet server conn pool get with setting count", func() int64 {
pool := se.pool.Load()
if pool == nil {
return 0
}

return pool.Metrics.GetSettingCount()
})
stats.NewCounterFunc(name+"DiffSetting", "Number of times pool applied different setting", func() int64 {
pool := se.pool.Load()
if pool == nil {
return 0
}

return pool.Metrics.DiffSettingCount()
})
stats.NewCounterFunc(name+"ResetSetting", "Number of times pool reset the setting", func() int64 {
pool := se.pool.Load()
if pool == nil {
return 0
}

return pool.Metrics.ResetSettingCount()
})

return se
}

func (se *StatsExporter[C]) SetPool(pool *ConnPool[C]) {
se.pool.Store(pool)
}
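
A note on the design: gauge and counter names can generally be registered with the exporter only once per process, while the pool behind them may be closed and rebuilt; the atomic pointer lets the long-lived callbacks follow whichever pool is current. A stripped-down sketch of the same pattern (the pool stand-in and names below are illustrative, not the PR's types):

package main

import (
	"fmt"
	"sync/atomic"
)

// pool is a stand-in for ConnPool with just enough surface for the sketch.
type pool struct{ capacity int64 }

func (p *pool) Capacity() int64 { return p.capacity }

// statsExporter mirrors the PR's StatsExporter: metric callbacks dereference
// an atomic pointer so the pool can be swapped after registration.
type statsExporter struct {
	pool atomic.Pointer[pool]
}

// capacityGauge is what a registered gauge callback would run on each scrape.
func (se *statsExporter) capacityGauge() int64 {
	p := se.pool.Load()
	if p == nil {
		return 0 // no pool attached yet: report zero rather than panic
	}
	return p.Capacity()
}

func main() {
	se := &statsExporter{}
	fmt.Println(se.capacityGauge()) // 0: nothing attached yet

	se.pool.Store(&pool{capacity: 10})
	fmt.Println(se.capacityGauge()) // 10: gauge reads the live pool

	se.pool.Store(&pool{capacity: 25}) // pool rebuilt, e.g. after Close/Open
	fmt.Println(se.capacityGauge())    // 25: same gauge, new pool
}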
25 changes: 9 additions & 16 deletions go/pools/smartconnpool/pool_test.go
@@ -566,23 +566,17 @@ func TestUserClosing(t *testing.T) {
r.Recycle()
}

ch := make(chan error)
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
p.CloseWithContext(ctx)

err := p.CloseWithContext(ctx)
ch <- err
close(ch)
}()
require.Equal(t, p.Active(), int64(1))
require.Equal(t, p.Capacity(), int64(0))
require.Equal(t, p.IsOpen(), false)

select {
case <-time.After(5 * time.Second):
t.Fatalf("Pool did not shut down after 5s")
case err := <-ch:
require.Error(t, err)
t.Logf("Shutdown error: %v", err)
}
resources[4].Recycle()

require.Equal(t, p.Active(), int64(0))

p.workers.Wait()
}

func TestConnReopen(t *testing.T) {
@@ -621,7 +615,6 @@ func TestConnReopen(t *testing.T) {
time.Sleep(300 * time.Millisecond)
// no active connection should be left.
assert.Zero(t, p.Active())

}

func TestIdleTimeout(t *testing.T) {