Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/src/auth_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ func init() {
}

func resolveAuthToken(ctx context.Context, cfg *config) (string, error) {
if err := cfg.requireCIAccessToken(); err != nil {
return "", err
}

if cfg.accessToken != "" {
return cfg.accessToken, nil
}
Expand Down
22 changes: 22 additions & 0 deletions cmd/src/auth_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,28 @@ func TestResolveAuthToken(t *testing.T) {
}
})

t.Run("requires access token in CI", func(t *testing.T) {
reset := stubAuthTokenDependencies(t)
defer reset()

loadCalled := false
loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) {
loadCalled = true
return nil, nil
}

_, err := resolveAuthToken(context.Background(), &config{
inCI: true,
endpointURL: mustParseURL(t, "https://example.com"),
})
if err != errCIAccessTokenRequired {
t.Fatalf("err = %v, want %v", err, errCIAccessTokenRequired)
}
if loadCalled {
t.Fatal("expected OAuth token loader not to be called")
}
})

t.Run("uses stored oauth token", func(t *testing.T) {
reset := stubAuthTokenDependencies(t)
defer reset()
Expand Down
4 changes: 4 additions & 0 deletions cmd/src/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ const (
)

func loginCmd(ctx context.Context, p loginParams) error {
if err := p.cfg.requireCIAccessToken(); err != nil {
return err
}

if p.cfg.configFilePath != "" {
fmt.Fprintln(p.out)
fmt.Fprintf(p.out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", p.cfg.configFilePath)
Expand Down
11 changes: 11 additions & 0 deletions cmd/src/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ func TestLogin(t *testing.T) {
}
})

t.Run("CI requires access token", func(t *testing.T) {
u := &url.URL{Scheme: "https", Host: "example.com"}
out, err := check(t, &config{endpointURL: u, inCI: true}, u)
if err != errCIAccessTokenRequired {
t.Fatalf("err = %v, want %v", err, errCIAccessTokenRequired)
}
if out != "" {
t.Fatalf("output = %q, want empty output", out)
}
})

t.Run("warning when using config file", func(t *testing.T) {
endpoint := &url.URL{Scheme: "https", Host: "example.com"}
out, err := check(t, &config{endpointURL: endpoint, configFilePath: "f"}, endpoint)
Expand Down
38 changes: 26 additions & 12 deletions cmd/src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ var (

errConfigMerge = errors.New("when using a configuration file, zero or all environment variables must be set")
errConfigAuthorizationConflict = errors.New("when passing an 'Authorization' additional headers, SRC_ACCESS_TOKEN must never be set")
errCIAccessTokenRequired = errors.New("SRC_ACCESS_TOKEN must be set in CI")
errCIAccessTokenRequired = errors.New("CI is true and SRC_ACCESS_TOKEN is not set or empty. When running in CI OAuth tokens cannot be used, only SRC_ACCESS_TOKEN. Either set CI=false or define a SRC_ACCESS_TOKEN")
)

// commands contains all registered subcommands.
Expand Down Expand Up @@ -137,6 +137,7 @@ type config struct {
proxyPath string
configFilePath string
endpointURL *url.URL // always non-nil; defaults to https://sourcegraph.com via readConfig
inCI bool
}

// configFromFile holds the config as read from the config file,
Expand All @@ -162,16 +163,32 @@ func (c *config) AuthMode() AuthMode {
return AuthModeOAuth
}

func (c *config) InCI() bool {
return c.inCI
}

func (c *config) requireCIAccessToken() error {
Comment thread
burmudar marked this conversation as resolved.
// In CI we typically do not have access to the keyring and the machine is also typically headless
// we therefore require SRC_ACCESS_TOKEN to be set when in CI.
// If someone really wants to run with OAuth in CI they can temporarily do CI=false
if c.InCI() && c.AuthMode() != AuthModeAccessToken {
return errCIAccessTokenRequired
Comment thread
burmudar marked this conversation as resolved.
}

return nil
}

// apiClient returns an api.Client built from the configuration.
func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client {
opts := api.ClientOpts{
EndpointURL: c.endpointURL,
AccessToken: c.accessToken,
AdditionalHeaders: c.additionalHeaders,
Flags: flags,
Out: out,
ProxyURL: c.proxyURL,
ProxyPath: c.proxyPath,
EndpointURL: c.endpointURL,
AccessToken: c.accessToken,
AdditionalHeaders: c.additionalHeaders,
Flags: flags,
Out: out,
ProxyURL: c.proxyURL,
ProxyPath: c.proxyPath,
RequireAccessTokenInCI: c.InCI(),
}

// Only use OAuth if we do not have SRC_ACCESS_TOKEN set
Expand Down Expand Up @@ -205,6 +222,7 @@ func readConfig() (*config, error) {

var cfgFromFile configFromFile
var cfg config
cfg.inCI = isCI()
var endpointStr string
var proxyStr string
if err == nil {
Expand Down Expand Up @@ -312,10 +330,6 @@ func readConfig() (*config, error) {
return nil, errConfigAuthorizationConflict
}

if isCI() && cfg.accessToken == "" {
return nil, errCIAccessTokenRequired
}

return &cfg, nil
}

Expand Down
47 changes: 44 additions & 3 deletions cmd/src/main_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package main

import (
"context"
"encoding/json"
"errors"

Check failure on line 6 in cmd/src/main_test.go

View workflow job for this annotation

GitHub Actions / go-lint

import 'errors' is not allowed from list 'main': Use github.com/sourcegraph/sourcegraph/lib/errors instead (depguard)
"io"
"net/url"
"os"
"path/filepath"
Expand Down Expand Up @@ -325,9 +328,13 @@
wantErr: errConfigAuthorizationConflict.Error(),
},
{
name: "CI requires access token",
envCI: "1",
wantErr: errCIAccessTokenRequired.Error(),
name: "CI does not require access token during config read",
envCI: "1",
want: &config{
endpointURL: &url.URL{Scheme: "https", Host: "sourcegraph.com"},
additionalHeaders: map[string]string{},
inCI: true,
},
},
{
name: "CI allows access token from config file",
Expand All @@ -340,6 +347,7 @@
endpointURL: &url.URL{Scheme: "https", Host: "example.com"},
accessToken: "deadbeef",
additionalHeaders: map[string]string{},
inCI: true,
},
},
}
Expand Down Expand Up @@ -422,3 +430,36 @@
}
})
}

func TestConfigAPIClientCIAccessTokenGate(t *testing.T) {
endpointURL := &url.URL{Scheme: "https", Host: "example.com"}

t.Run("requires access token in CI", func(t *testing.T) {
client := (&config{endpointURL: endpointURL, inCI: true}).apiClient(nil, io.Discard)

_, err := client.NewHTTPRequest(context.Background(), "GET", ".api/src-cli/version", nil)
if !errors.Is(err, api.ErrCIAccessTokenRequired) {
t.Fatalf("NewHTTPRequest() error = %v, want %v", err, api.ErrCIAccessTokenRequired)
}
})

t.Run("allows access token in CI", func(t *testing.T) {
client := (&config{endpointURL: endpointURL, inCI: true, accessToken: "abc"}).apiClient(nil, io.Discard)

req, err := client.NewHTTPRequest(context.Background(), "GET", ".api/src-cli/version", nil)
if err != nil {
t.Fatalf("NewHTTPRequest() unexpected error: %s", err)
}
if got := req.Header.Get("Authorization"); got != "token abc" {
t.Fatalf("Authorization header = %q, want %q", got, "token abc")
}
})

t.Run("allows oauth mode outside CI", func(t *testing.T) {
client := (&config{endpointURL: endpointURL}).apiClient(nil, io.Discard)

if _, err := client.NewHTTPRequest(context.Background(), "GET", ".api/src-cli/version", nil); err != nil {
t.Fatalf("NewHTTPRequest() unexpected error: %s", err)
}
})
}
7 changes: 1 addition & 6 deletions cmd/src/search_jobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,7 @@ func parseColumns(columnsFlag string) []string {

// createSearchJobsClient creates a reusable API client for search jobs commands
func createSearchJobsClient(out *flag.FlagSet, apiFlags *api.Flags) api.Client {
return api.NewClient(api.ClientOpts{
EndpointURL: cfg.endpointURL,
AccessToken: cfg.accessToken,
Out: out.Output(),
Flags: apiFlags,
})
return cfg.apiClient(apiFlags, out.Output())
}

// parseSearchJobsArgs parses command arguments with the provided flag set
Expand Down
41 changes: 33 additions & 8 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"context"
"crypto/tls"
"encoding/json"
"errors"

Check failure on line 9 in internal/api/api.go

View workflow job for this annotation

GitHub Actions / go-lint

import 'errors' is not allowed from list 'main': Use github.com/sourcegraph/sourcegraph/lib/errors instead (depguard)
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -71,9 +72,10 @@

// ClientOpts encapsulates the options given to NewClient.
type ClientOpts struct {
EndpointURL *url.URL
AccessToken string
AdditionalHeaders map[string]string
EndpointURL *url.URL
AccessToken string
AdditionalHeaders map[string]string
RequireAccessTokenInCI bool

// Flags are the standard API client flags provided by NewFlags. If nil,
// default values will be used.
Expand All @@ -89,6 +91,9 @@
OAuthToken *oauth.Token
}

// ErrCIAccessTokenRequired indicates SRC_ACCESS_TOKEN must be set when CI=true.
var ErrCIAccessTokenRequired = errors.New("SRC_ACCESS_TOKEN must be set when CI=true")

func buildTransport(opts ClientOpts, flags *Flags) http.RoundTripper {
var transport http.RoundTripper
{
Expand All @@ -109,6 +114,9 @@
transport = tp
}

// not we do not fail here if requireAccessToken is true, because that would
// mean returning an error on construction which we want to avoid for now
// TODO(burmudar): allow returning of an error upon client construction
if opts.AccessToken == "" && opts.OAuthToken != nil {
transport = oauth.NewTransport(transport, opts.OAuthToken)
}
Expand All @@ -135,15 +143,24 @@

return &client{
opts: ClientOpts{
EndpointURL: opts.EndpointURL,
AccessToken: opts.AccessToken,
AdditionalHeaders: opts.AdditionalHeaders,
Flags: flags,
Out: opts.Out,
EndpointURL: opts.EndpointURL,
AccessToken: opts.AccessToken,
AdditionalHeaders: opts.AdditionalHeaders,
RequireAccessTokenInCI: opts.RequireAccessTokenInCI,
Flags: flags,
Out: opts.Out,
},
httpClient: httpClient,
}
}

func (c *client) checkIfCIAccessTokenRequired() error {
if c.opts.RequireAccessTokenInCI && c.opts.AccessToken == "" {
return ErrCIAccessTokenRequired
}

return nil
}
func (c *client) NewQuery(query string) Request {
return c.NewRequest(query, nil)
}
Expand All @@ -170,6 +187,10 @@
}

func (c *client) createHTTPRequest(ctx context.Context, method, p string, body io.Reader) (*http.Request, error) {
if err := c.checkIfCIAccessTokenRequired(); err != nil {
return nil, err
}

// Can't use c.opts.EndpointURL.JoinPath(p) here because `p` could contain a query string
req, err := http.NewRequestWithContext(ctx, method, c.opts.EndpointURL.String()+"/"+p, body)
if err != nil {
Expand Down Expand Up @@ -199,6 +220,10 @@
}

func (r *request) do(ctx context.Context, result any) (bool, error) {
if err := r.client.checkIfCIAccessTokenRequired(); err != nil {
return false, err
}

if *r.client.opts.Flags.getCurl {
curl, err := r.curlCmd()
if err != nil {
Expand Down
Loading