Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions telemetry/circuitbreaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,15 @@ func getCircuitBreakerManager() *circuitBreakerManager {

// getCircuitBreaker gets or creates a circuit breaker for the host.
// Thread-safe for concurrent access.
//
// The host is normalized (scheme stripped, lowercased, trailing slash trimmed)
// so DSN variants ("example.com", "example.com/", "https://example.com")
// share a single breaker per logical host instead of fragmenting trip state.
func (m *circuitBreakerManager) getCircuitBreaker(host string) *circuitBreaker {
key := normalizeHostKey(host)

m.mu.RLock()
cb, exists := m.breakers[host]
cb, exists := m.breakers[key]
m.mu.RUnlock()

if exists {
Expand All @@ -345,11 +351,11 @@ func (m *circuitBreakerManager) getCircuitBreaker(host string) *circuitBreaker {
defer m.mu.Unlock()

// Double-check after acquiring write lock
if cb, exists = m.breakers[host]; exists {
if cb, exists = m.breakers[key]; exists {
return cb
}

cb = newCircuitBreaker(defaultCircuitBreakerConfig())
m.breakers[host] = cb
m.breakers[key] = cb
return cb
}
24 changes: 24 additions & 0 deletions telemetry/circuitbreaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,30 @@ func TestCircuitBreakerManager_PerHostIsolation(t *testing.T) {
}
}

func TestCircuitBreakerManager_HostVariantsShareBreaker(t *testing.T) {
mgr := &circuitBreakerManager{breakers: make(map[string]*circuitBreaker)}

canonical := "normalize.example.com"
variants := []string{
canonical,
canonical + "/",
"https://" + canonical,
"HTTPS://Normalize.Example.com/",
"http://" + canonical + "/",
}

first := mgr.getCircuitBreaker(variants[0])
for _, v := range variants[1:] {
if got := mgr.getCircuitBreaker(v); got != first {
t.Errorf("variant %q returned a different breaker than %q", v, variants[0])
}
}

if len(mgr.breakers) != 1 {
t.Errorf("expected 1 breaker for all host variants, got %d", len(mgr.breakers))
}
}

func TestCircuitBreakerManager_ConcurrentAccess(t *testing.T) {
mgr := getCircuitBreakerManager()
var wg sync.WaitGroup
Expand Down
25 changes: 21 additions & 4 deletions telemetry/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,24 @@ package telemetry

import (
"net/http"
"strings"
"sync"

"github.com/databricks/databricks-sql-go/logger"
)

// normalizeHostKey returns a canonical lookup key for host-based registries
// (telemetry clients, circuit breakers). It lowercases, trims whitespace,
// strips http/https scheme, and trims trailing slashes so trivial variations
// in DSN input ("example.com", "example.com/", "https://example.com") share
// state instead of fragmenting into independent breakers/clients.
func normalizeHostKey(host string) string {
h := strings.ToLower(strings.TrimSpace(host))
h = strings.TrimPrefix(h, "https://")
h = strings.TrimPrefix(h, "http://")
return strings.TrimRight(h, "/")
}

// clientManager manages one telemetry client per host.
//
// Design:
Expand Down Expand Up @@ -53,10 +66,12 @@ func getClientManager() *clientManager {
// per-host singleton consolidates telemetry across connections to keep
// the request rate low.
func (m *clientManager) getOrCreateClient(host string, driverVersion string, userAgent string, httpClient *http.Client, cfg *Config) *telemetryClient {
key := normalizeHostKey(host)

m.mu.Lock()
defer m.mu.Unlock()

holder, exists := m.clients[host]
holder, exists := m.clients[key]
if !exists {
client := newTelemetryClient(host, driverVersion, userAgent, httpClient, cfg)
if err := client.start(); err != nil {
Expand All @@ -67,7 +82,7 @@ func (m *clientManager) getOrCreateClient(host string, driverVersion string, use
holder = &clientHolder{
client: client,
}
m.clients[host] = holder
m.clients[key] = holder
}
holder.refCount++
return holder.client
Expand All @@ -76,8 +91,10 @@ func (m *clientManager) getOrCreateClient(host string, driverVersion string, use
// releaseClient decrements reference count for the host.
// Closes and removes client when ref count reaches zero.
func (m *clientManager) releaseClient(host string) error {
key := normalizeHostKey(host)

m.mu.Lock()
holder, exists := m.clients[host]
holder, exists := m.clients[key]
if !exists {
m.mu.Unlock()
return nil
Expand All @@ -89,7 +106,7 @@ func (m *clientManager) releaseClient(host string) error {
logger.Logger.Debug().Str("host", host).Int("refCount", holder.refCount).Msg("telemetry client refCount became negative")
}
if holder.refCount <= 0 {
delete(m.clients, host)
delete(m.clients, key)
m.mu.Unlock()
return holder.client.close() // Close and flush
}
Expand Down
59 changes: 59 additions & 0 deletions telemetry/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,65 @@ func TestClientManager_ShutdownWithActiveRefs(t *testing.T) {
}
}

func TestNormalizeHostKey(t *testing.T) {
cases := []struct {
in, want string
}{
{"example.databricks.com", "example.databricks.com"},
{"example.databricks.com/", "example.databricks.com"},
{"example.databricks.com//", "example.databricks.com"},
{"https://example.databricks.com", "example.databricks.com"},
{"http://example.databricks.com", "example.databricks.com"},
{"HTTPS://Example.Databricks.com/", "example.databricks.com"},
{" example.databricks.com ", "example.databricks.com"},
{"", ""},
}
for _, c := range cases {
if got := normalizeHostKey(c.in); got != c.want {
t.Errorf("normalizeHostKey(%q) = %q, want %q", c.in, got, c.want)
}
}
}

func TestClientManager_HostVariantsShareClient(t *testing.T) {
manager := &clientManager{
clients: make(map[string]*clientHolder),
}
httpClient := &http.Client{}
cfg := DefaultConfig()

variants := []string{
"example.databricks.com",
"example.databricks.com/",
"https://example.databricks.com",
"HTTPS://Example.Databricks.com/",
}

first := manager.getOrCreateClient(variants[0], "v", "ua", httpClient, cfg)
if first == nil {
t.Fatal("expected client")
}
for _, v := range variants[1:] {
got := manager.getOrCreateClient(v, "v", "ua", httpClient, cfg)
if got != first {
t.Errorf("variant %q got a different client instance — should share with %q", v, variants[0])
}
}
if len(manager.clients) != 1 {
t.Errorf("expected 1 holder for all variants, got %d", len(manager.clients))
}

// Release using a different variant still finds the holder.
for range variants {
if err := manager.releaseClient("https://example.databricks.com/"); err != nil {
t.Fatalf("release: %v", err)
}
}
if len(manager.clients) != 0 {
t.Errorf("expected holder removed after all releases, got %d", len(manager.clients))
}
}

func TestClientManager_ShutdownEmptyManager(t *testing.T) {
manager := &clientManager{
clients: make(map[string]*clientHolder),
Expand Down
Loading