diff --git a/telemetry/circuitbreaker.go b/telemetry/circuitbreaker.go index a6fbfc8..4b8f3f2 100644 --- a/telemetry/circuitbreaker.go +++ b/telemetry/circuitbreaker.go @@ -332,9 +332,15 @@ func getCircuitBreakerManager() *circuitBreakerManager { // getCircuitBreaker gets or creates a circuit breaker for the host. // Thread-safe for concurrent access. +// +// The host is normalized (scheme stripped, lowercased, trailing slash trimmed) +// so DSN variants ("example.com", "example.com/", "https://example.com") +// share a single breaker per logical host instead of fragmenting trip state. func (m *circuitBreakerManager) getCircuitBreaker(host string) *circuitBreaker { + key := normalizeHostKey(host) + m.mu.RLock() - cb, exists := m.breakers[host] + cb, exists := m.breakers[key] m.mu.RUnlock() if exists { @@ -345,11 +351,11 @@ func (m *circuitBreakerManager) getCircuitBreaker(host string) *circuitBreaker { defer m.mu.Unlock() // Double-check after acquiring write lock - if cb, exists = m.breakers[host]; exists { + if cb, exists = m.breakers[key]; exists { return cb } cb = newCircuitBreaker(defaultCircuitBreakerConfig()) - m.breakers[host] = cb + m.breakers[key] = cb return cb } diff --git a/telemetry/circuitbreaker_test.go b/telemetry/circuitbreaker_test.go index c9d7a88..fc2b309 100644 --- a/telemetry/circuitbreaker_test.go +++ b/telemetry/circuitbreaker_test.go @@ -381,6 +381,30 @@ func TestCircuitBreakerManager_PerHostIsolation(t *testing.T) { } } +func TestCircuitBreakerManager_HostVariantsShareBreaker(t *testing.T) { + mgr := &circuitBreakerManager{breakers: make(map[string]*circuitBreaker)} + + canonical := "normalize.example.com" + variants := []string{ + canonical, + canonical + "/", + "https://" + canonical, + "HTTPS://Normalize.Example.com/", + "http://" + canonical + "/", + } + + first := mgr.getCircuitBreaker(variants[0]) + for _, v := range variants[1:] { + if got := mgr.getCircuitBreaker(v); got != first { + t.Errorf("variant %q returned a different breaker than %q", v, variants[0]) + } + } + + if len(mgr.breakers) != 1 { + t.Errorf("expected 1 breaker for all host variants, got %d", len(mgr.breakers)) + } +} + func TestCircuitBreakerManager_ConcurrentAccess(t *testing.T) { mgr := getCircuitBreakerManager() var wg sync.WaitGroup diff --git a/telemetry/manager.go b/telemetry/manager.go index 19a72c2..bba7d1e 100644 --- a/telemetry/manager.go +++ b/telemetry/manager.go @@ -2,11 +2,24 @@ package telemetry import ( "net/http" + "strings" "sync" "github.com/databricks/databricks-sql-go/logger" ) +// normalizeHostKey returns a canonical lookup key for host-based registries +// (telemetry clients, circuit breakers). It lowercases, trims whitespace, +// strips http/https scheme, and trims trailing slashes so trivial variations +// in DSN input ("example.com", "example.com/", "https://example.com") share +// state instead of fragmenting into independent breakers/clients. +func normalizeHostKey(host string) string { + h := strings.ToLower(strings.TrimSpace(host)) + h = strings.TrimPrefix(h, "https://") + h = strings.TrimPrefix(h, "http://") + return strings.TrimRight(h, "/") +} + // clientManager manages one telemetry client per host. // // Design: @@ -53,10 +66,12 @@ func getClientManager() *clientManager { // per-host singleton consolidates telemetry across connections to keep // the request rate low. func (m *clientManager) getOrCreateClient(host string, driverVersion string, userAgent string, httpClient *http.Client, cfg *Config) *telemetryClient { + key := normalizeHostKey(host) + m.mu.Lock() defer m.mu.Unlock() - holder, exists := m.clients[host] + holder, exists := m.clients[key] if !exists { client := newTelemetryClient(host, driverVersion, userAgent, httpClient, cfg) if err := client.start(); err != nil { @@ -67,7 +82,7 @@ func (m *clientManager) getOrCreateClient(host string, driverVersion string, use holder = &clientHolder{ client: client, } - m.clients[host] = holder + m.clients[key] = holder } holder.refCount++ return holder.client @@ -76,8 +91,10 @@ func (m *clientManager) getOrCreateClient(host string, driverVersion string, use // releaseClient decrements reference count for the host. // Closes and removes client when ref count reaches zero. func (m *clientManager) releaseClient(host string) error { + key := normalizeHostKey(host) + m.mu.Lock() - holder, exists := m.clients[host] + holder, exists := m.clients[key] if !exists { m.mu.Unlock() return nil @@ -89,7 +106,7 @@ func (m *clientManager) releaseClient(host string) error { logger.Logger.Debug().Str("host", host).Int("refCount", holder.refCount).Msg("telemetry client refCount became negative") } if holder.refCount <= 0 { - delete(m.clients, host) + delete(m.clients, key) m.mu.Unlock() return holder.client.close() // Close and flush } diff --git a/telemetry/manager_test.go b/telemetry/manager_test.go index 3567d3b..e8f346c 100644 --- a/telemetry/manager_test.go +++ b/telemetry/manager_test.go @@ -401,6 +401,65 @@ func TestClientManager_ShutdownWithActiveRefs(t *testing.T) { } } +func TestNormalizeHostKey(t *testing.T) { + cases := []struct { + in, want string + }{ + {"example.databricks.com", "example.databricks.com"}, + {"example.databricks.com/", "example.databricks.com"}, + {"example.databricks.com//", "example.databricks.com"}, + {"https://example.databricks.com", "example.databricks.com"}, + {"http://example.databricks.com", "example.databricks.com"}, + {"HTTPS://Example.Databricks.com/", "example.databricks.com"}, + {" example.databricks.com ", "example.databricks.com"}, + {"", ""}, + } + for _, c := range cases { + if got := normalizeHostKey(c.in); got != c.want { + t.Errorf("normalizeHostKey(%q) = %q, want %q", c.in, got, c.want) + } + } +} + +func TestClientManager_HostVariantsShareClient(t *testing.T) { + manager := &clientManager{ + clients: make(map[string]*clientHolder), + } + httpClient := &http.Client{} + cfg := DefaultConfig() + + variants := []string{ + "example.databricks.com", + "example.databricks.com/", + "https://example.databricks.com", + "HTTPS://Example.Databricks.com/", + } + + first := manager.getOrCreateClient(variants[0], "v", "ua", httpClient, cfg) + if first == nil { + t.Fatal("expected client") + } + for _, v := range variants[1:] { + got := manager.getOrCreateClient(v, "v", "ua", httpClient, cfg) + if got != first { + t.Errorf("variant %q got a different client instance — should share with %q", v, variants[0]) + } + } + if len(manager.clients) != 1 { + t.Errorf("expected 1 holder for all variants, got %d", len(manager.clients)) + } + + // Release using a different variant still finds the holder. + for range variants { + if err := manager.releaseClient("https://example.databricks.com/"); err != nil { + t.Fatalf("release: %v", err) + } + } + if len(manager.clients) != 0 { + t.Errorf("expected holder removed after all releases, got %d", len(manager.clients)) + } +} + func TestClientManager_ShutdownEmptyManager(t *testing.T) { manager := &clientManager{ clients: make(map[string]*clientHolder),