Skip to content

Commit 947ae61

Browse files
committed
feat(McpHub): handle OAuth refresh mid tool call
1 parent e580b77 commit 947ae61

5 files changed

Lines changed: 263 additions & 28 deletions

File tree

src/services/mcp/McpHub.ts

Lines changed: 90 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ export type ConnectedMcpConnection = {
4949
server: McpServer
5050
client: Client
5151
transport: StdioClientTransport | SSEClientTransport | StreamableHTTPClientTransport
52+
authProvider?: McpOAuthClientProvider
5253
}
5354

5455
export type DisconnectedMcpConnection = {
@@ -166,6 +167,7 @@ export class McpHub {
166167
private sanitizedNameRegistry: Map<string, string> = new Map()
167168
private initializationPromise: Promise<void>
168169
private secretStorage?: SecretStorageService
170+
private reauthPromises: Map<string, Promise<void>> = new Map()
169171

170172
constructor(provider: ClineProvider) {
171173
this.providerRef = new WeakRef(provider)
@@ -822,6 +824,31 @@ export class McpHub {
822824
console.error(`Transport error for "${name}" (streamable-http):`, error)
823825
const connection = this.findConnection(name, source)
824826
if (connection) {
827+
if (error instanceof UnauthorizedError && authProvider) {
828+
// Mid-session re-auth triggered by a tool call (401)
829+
connection.server.status = "connecting"
830+
831+
const reauthKey = `${name}:${source}`
832+
let reauthPromise = this.reauthPromises.get(reauthKey)
833+
if (!reauthPromise) {
834+
reauthPromise = this._completeOAuthFlow(
835+
authProvider,
836+
transport as StreamableHTTPClientTransport,
837+
connection as ConnectedMcpConnection,
838+
name,
839+
source,
840+
)
841+
.catch((err) => {
842+
console.error(`OAuth flow failed for "${name}":`, err)
843+
})
844+
.finally(() => {
845+
this.reauthPromises.delete(reauthKey)
846+
})
847+
this.reauthPromises.set(reauthKey, reauthPromise)
848+
}
849+
return
850+
}
851+
825852
connection.server.status = "disconnected"
826853
this.appendErrorMessage(connection, error instanceof Error ? error.message : `${error}`)
827854
}
@@ -905,6 +932,7 @@ export class McpHub {
905932
},
906933
client,
907934
transport,
935+
authProvider: streamableHttpAuthProvider,
908936
}
909937
this.connections.push(connection)
910938

@@ -935,6 +963,7 @@ export class McpHub {
935963
}
936964

937965
// Successful connection — close callback server if it was started.
966+
// We keep the authProvider on the connection so it can handle mid-session 401s.
938967
await streamableHttpAuthProvider?.close()
939968

940969
connection.server.status = "connected"
@@ -1203,9 +1232,10 @@ export class McpHub {
12031232
if (connection.type === "connected") {
12041233
await connection.transport.close()
12051234
await connection.client.close()
1235+
await connection.authProvider?.close()
12061236
}
12071237
} catch (error) {
1208-
console.error(`Failed to close transport for ${name}:`, error)
1238+
console.error(`Failed to close transport or auth provider for ${name}:`, error)
12091239
}
12101240
}
12111241

@@ -1876,19 +1906,66 @@ export class McpHub {
18761906
timeout = 60 * 1000
18771907
}
18781908

1879-
return await connection.client.request(
1880-
{
1881-
method: "tools/call",
1882-
params: {
1883-
name: toolName,
1884-
arguments: toolArguments,
1909+
try {
1910+
return await connection.client.request(
1911+
{
1912+
method: "tools/call",
1913+
params: {
1914+
name: toolName,
1915+
arguments: toolArguments,
1916+
},
18851917
},
1886-
},
1887-
CallToolResultSchema,
1888-
{
1889-
timeout,
1890-
},
1891-
)
1918+
CallToolResultSchema,
1919+
{
1920+
timeout,
1921+
},
1922+
)
1923+
} catch (error) {
1924+
if (error instanceof UnauthorizedError && connection.authProvider) {
1925+
// Mid-session re-auth triggered by a tool call (401)
1926+
connection.server.status = "connecting"
1927+
1928+
const reauthKey = `${serverName}:${source || connection.server.source || "global"}`
1929+
let reauthPromise = this.reauthPromises.get(reauthKey)
1930+
1931+
if (!reauthPromise) {
1932+
reauthPromise = this._completeOAuthFlow(
1933+
connection.authProvider,
1934+
connection.transport as StreamableHTTPClientTransport,
1935+
connection,
1936+
serverName,
1937+
source || connection.server.source || "global",
1938+
).finally(() => {
1939+
this.reauthPromises.delete(reauthKey)
1940+
})
1941+
this.reauthPromises.set(reauthKey, reauthPromise)
1942+
}
1943+
1944+
await reauthPromise
1945+
1946+
// After re-auth completes, the connection has been replaced.
1947+
// We need to find the new connection and retry the tool call.
1948+
const newConnection = this.findConnection(serverName, source)
1949+
if (!newConnection || newConnection.type !== "connected") {
1950+
throw new Error(`Failed to reconnect to server ${serverName} after OAuth`)
1951+
}
1952+
1953+
return await newConnection.client.request(
1954+
{
1955+
method: "tools/call",
1956+
params: {
1957+
name: toolName,
1958+
arguments: toolArguments,
1959+
},
1960+
},
1961+
CallToolResultSchema,
1962+
{
1963+
timeout,
1964+
},
1965+
)
1966+
}
1967+
throw error
1968+
}
18921969
}
18931970

18941971
/**

src/services/mcp/McpOAuthClientProvider.ts

Lines changed: 98 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,14 @@ export class McpOAuthClientProvider implements OAuthClientProvider {
3636
// when the redirect URI port changes between sessions.
3737
private _clientInfo?: OAuthClientInformationFull
3838
private _closed = false
39+
private _refreshPromise: Promise<OAuthTokens> | null = null
3940

4041
private constructor(
4142
private readonly _serverUrl: string,
4243
private readonly _secretStorage: SecretStorageService,
43-
private readonly _server: http.Server,
44-
private readonly _port: number,
45-
private readonly _authCodePromise: Promise<string>,
44+
private _server: http.Server | null,
45+
private _port: number,
46+
private _authCodePromise: Promise<string> | null,
4647
private readonly _tokenEndpointAuthMethod: string,
4748
private readonly _grantTypes: string[],
4849
private readonly _scopes: string[],
@@ -120,6 +121,20 @@ export class McpOAuthClientProvider implements OAuthClientProvider {
120121
return `http://localhost:${this._port}/callback`
121122
}
122123

124+
private async _ensureCallbackServer(): Promise<void> {
125+
if (this._server && !this._closed) return
126+
127+
this._closed = false
128+
const { server, port, result } = await startCallbackServer(this._port, this._state)
129+
this._server = server
130+
this._port = port
131+
this._authCodePromise = result.then((r) => {
132+
if (r.error) throw new Error(`OAuth authorization failed: ${r.error}`)
133+
if (!r.code) throw new Error("No authorization code received in callback")
134+
return r.code
135+
})
136+
}
137+
123138
state(): string {
124139
return this._state
125140
}
@@ -192,10 +207,33 @@ export class McpOAuthClientProvider implements OAuthClientProvider {
192207
async tokens(): Promise<OAuthTokens | undefined> {
193208
const data = await this._secretStorage.getOAuthData(this._serverUrl)
194209
if (!data) return undefined
195-
// Return undefined 5 minutes before expiry so the SDK triggers re-auth
196-
// before the server actually rejects requests.
197-
if (Date.now() >= data.expires_at - 5 * 60 * 1000) return undefined
198-
return data.tokens
210+
211+
// If the access token is still valid (with 5m buffer), return it.
212+
if (Date.now() < data.expires_at - 5 * 60 * 1000) {
213+
return data.tokens
214+
}
215+
216+
// Access token is expired or near expiry. Try to refresh if we have a refresh token.
217+
if (data.tokens.refresh_token) {
218+
if (this._refreshPromise) {
219+
return this._refreshPromise
220+
}
221+
222+
this._refreshPromise = this.refreshAccessToken(data.tokens.refresh_token).finally(() => {
223+
this._refreshPromise = null
224+
})
225+
226+
try {
227+
return await this._refreshPromise
228+
} catch (error) {
229+
console.error(`Failed to refresh MCP OAuth token for ${this._serverUrl}:`, error)
230+
// Clear stale tokens on refresh failure so we don't keep retrying a dead refresh token
231+
await this._secretStorage.deleteOAuthData(this._serverUrl)
232+
// Fall through to return undefined, which triggers full re-auth
233+
}
234+
}
235+
236+
return undefined
199237
}
200238

201239
async saveTokens(tokens: OAuthTokens): Promise<void> {
@@ -209,6 +247,10 @@ export class McpOAuthClientProvider implements OAuthClientProvider {
209247
}
210248

211249
async redirectToAuthorization(authorizationUrl: URL): Promise<void> {
250+
// Ensure the callback server is running before opening the browser.
251+
// This handles mid-session re-auth where the initial server was closed.
252+
await this._ensureCallbackServer()
253+
212254
// Workaround for SDK metadata discovery bug (see utils/oauth.ts for issue links).
213255
// The SDK's discoverOAuthMetadata() builds a wrong well-known URL for issuers
214256
// with path components, causing it to fall back to a default "/authorize" path.
@@ -267,8 +309,11 @@ export class McpOAuthClientProvider implements OAuthClientProvider {
267309
* browser flow and the local callback server receives the redirect.
268310
* Rejects on error or 5-minute timeout.
269311
*/
270-
waitForAuthCode(): Promise<string> {
271-
return this._authCodePromise
312+
async waitForAuthCode(): Promise<string> {
313+
if (!this._authCodePromise) {
314+
await this._ensureCallbackServer()
315+
}
316+
return this._authCodePromise!
272317
}
273318

274319
/**
@@ -327,11 +372,54 @@ export class McpOAuthClientProvider implements OAuthClientProvider {
327372
await this.saveTokens(tokens)
328373
}
329374

375+
/**
376+
* Refreshes the access token using a refresh token.
377+
* @param refreshToken The refresh token to use.
378+
* @returns The new tokens.
379+
*/
380+
async refreshAccessToken(refreshToken: string): Promise<OAuthTokens> {
381+
if (!this._authServerMeta?.token_endpoint) {
382+
throw new Error("No token_endpoint in auth server metadata — cannot refresh token")
383+
}
384+
if (!this._clientInfo) {
385+
throw new Error("No client information — registerClientIfNeeded() must be called first")
386+
}
387+
388+
const params: Record<string, string> = {
389+
grant_type: "refresh_token",
390+
refresh_token: refreshToken,
391+
client_id: this._clientInfo.client_id,
392+
}
393+
394+
if (this._tokenEndpointAuthMethod === "client_secret_post" && this._clientInfo.client_secret) {
395+
params.client_secret = this._clientInfo.client_secret
396+
}
397+
398+
const response = await fetch(this._authServerMeta.token_endpoint as string, {
399+
method: "POST",
400+
headers: {
401+
"Content-Type": "application/x-www-form-urlencoded",
402+
Accept: "application/json",
403+
},
404+
body: new URLSearchParams(params).toString(),
405+
})
406+
407+
if (!response.ok) {
408+
throw new Error(`Token refresh failed: HTTP ${response.status}`)
409+
}
410+
411+
const tokens = (await response.json()) as OAuthTokens
412+
await this.saveTokens(tokens)
413+
return tokens
414+
}
415+
330416
/** Close the local callback server. Always call this when done. */
331417
async close(): Promise<void> {
332-
if (!this._closed) {
418+
if (!this._closed && this._server) {
333419
this._closed = true
334420
await stopCallbackServer(this._server).catch(() => {})
421+
this._server = null
422+
this._authCodePromise = null
335423
}
336424
}
337425
}

src/services/mcp/SecretStorageService.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@ export class SecretStorageService {
2525
}
2626

2727
private _key(serverUrl: string): string {
28-
return `${this._namespace}${new URL(serverUrl).host}.data`
28+
const url = new URL(serverUrl)
29+
// Use host + pathname for stricter isolation between different MCP servers on the same host.
30+
// We sanitize the pathname to ensure it's a valid key component.
31+
const sanitizedPath = url.pathname.replace(/[^a-zA-Z0-9]/g, "_").replace(/^_+|_+$/g, "")
32+
const pathSuffix = sanitizedPath ? `.${sanitizedPath}` : ""
33+
return `${this._namespace}${url.host}${pathSuffix}.data`
2934
}
3035

3136
async getOAuthData(serverUrl: string): Promise<StoredMcpOAuthData | undefined> {

src/services/mcp/__tests__/McpOAuthClientProvider.spec.ts

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,46 @@ describe("McpOAuthClientProvider", () => {
178178
await provider.close()
179179
})
180180

181-
it("should return undefined for expired tokens", async () => {
181+
it("should refresh tokens when access token is expired but refresh token exists", async () => {
182+
setupCallbackServerMock()
183+
const secretStorage = createMockSecretStorage()
184+
const provider = await McpOAuthClientProvider.create("https://example.com/mcp", secretStorage)
185+
186+
const initialTokens = {
187+
access_token: "expired-access",
188+
refresh_token: "valid-refresh",
189+
token_type: "Bearer",
190+
}
191+
const refreshedTokens = {
192+
access_token: "new-access",
193+
refresh_token: "new-refresh",
194+
token_type: "Bearer",
195+
expires_in: 3600,
196+
}
197+
198+
await provider.saveClientInformation({ client_id: "id", redirect_uris: [] } as any)
199+
await secretStorage.saveOAuthData("https://example.com/mcp", {
200+
tokens: initialTokens,
201+
expires_at: Date.now() - 1000,
202+
})
203+
204+
mockFetch.mockResolvedValueOnce({
205+
ok: true,
206+
json: () => Promise.resolve(refreshedTokens),
207+
})
208+
209+
const result = await provider.tokens()
210+
expect(result).toEqual(refreshedTokens)
211+
expect(mockFetch).toHaveBeenCalledWith(
212+
expect.stringContaining("/token"),
213+
expect.objectContaining({
214+
body: expect.stringContaining("grant_type=refresh_token"),
215+
}),
216+
)
217+
await provider.close()
218+
})
219+
220+
it("should return undefined for expired tokens without refresh token", async () => {
182221
setupCallbackServerMock()
183222
const secretStorage = createMockSecretStorage()
184223
const provider = await McpOAuthClientProvider.create("https://example.com/mcp", secretStorage)

0 commit comments

Comments
 (0)