Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# v1.3.1
## Fixes
- Reuse OAuth2 token source to prevent unnecessary token fetches for each request.

# v1.3.0

## Features
Expand Down
16 changes: 14 additions & 2 deletions auth_providers/auth_oauth.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 Keyfactor
// Copyright 2026 Keyfactor
Comment thread
irby marked this conversation as resolved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -22,6 +22,7 @@ import (
"net/http"
"os"
"strings"
"sync"
"time"

"golang.org/x/oauth2"
Expand Down Expand Up @@ -115,6 +116,10 @@ type CommandConfigOauth struct {

// TokenURL is the token URL for OAuth authentication
TokenURL string `json:"token_url,omitempty" yaml:"token_url,omitempty"`

// unexported: lazily initialized, shared across GetHttpClient() calls
tokenSource oauth2.TokenSource
tsMu sync.Mutex
}

// NewOAuthAuthenticatorBuilder creates a new CommandConfigOauth instance.
Expand Down Expand Up @@ -222,7 +227,14 @@ func (b *CommandConfigOauth) GetHttpClient() (*http.Client, error) {
}

ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: baseTransport})
tokenSource := config.TokenSource(ctx)

// Lazily initialize the token source and cache it
b.tsMu.Lock()
if b.tokenSource == nil {
b.tokenSource = config.TokenSource(ctx)
}
tokenSource := b.tokenSource
b.tsMu.Unlock()

client = http.Client{
Transport: &oauth2Transport{
Expand Down
49 changes: 48 additions & 1 deletion auth_providers/auth_oauth_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 Keyfactor
// Copyright 2026 Keyfactor
Comment thread
irby marked this conversation as resolved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -16,9 +16,11 @@ package auth_providers_test

import (
"crypto/tls"
"encoding/json"
"encoding/pem"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
Expand Down Expand Up @@ -568,3 +570,48 @@ func DownloadCertificate(input string, outputPath string) error {
fmt.Printf("Certificate chain saved to: %s\n", outputFile)
return nil
}

func TestCommandConfigOauth_TokenSourceIsReused(t *testing.T) {
tokenRequestCount := 0

// Fake IdP token endpoint
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenRequestCount++
w.Header().Set("Content-Type", "application/json")
Comment thread
irby marked this conversation as resolved.
Outdated
json.NewEncoder(w).Encode(map[string]interface{}{
"access_token": "shared-test-token",
"token_type": "Bearer",
"expires_in": 3600,
})
}))
defer tokenServer.Close()

// Fake API endpoint (just needs to accept requests)
apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer apiServer.Close()

config := &auth_providers.CommandConfigOauth{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
TokenURL: tokenServer.URL + "/token",
}

// Get multiple clients from the same config
const numClients = 3
for i := 0; i < numClients; i++ {
client, err := config.GetHttpClient()
if err != nil {
t.Fatalf("GetHttpClient() call %d failed: %v", i+1, err)
}
_, err = client.Get(apiServer.URL)
if err != nil {
t.Fatalf("request %d failed: %v", i+1, err)
}
Comment thread
irby marked this conversation as resolved.
Outdated
}

if tokenRequestCount != 1 {
t.Errorf("expected token endpoint to be called once, got %d — token source is not being reused across GetHttpClient() calls", tokenRequestCount)
}
}
Loading