Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ so that your API key is not stored in source control.

### Workload Identity Authentication

For secure, automated environments like cloud-managed Kubernetes, Azure, and Google Cloud Platform, you can use workload identity authentication with short-lived tokens from cloud identity providers instead of long-lived API keys.
For secure, automated environments like cloud-managed Kubernetes, Azure, Google Cloud Platform, and AWS Bedrock, you can use workload identity authentication with short-lived tokens from cloud identity providers instead of long-lived API keys.

#### Kubernetes (service account tokens)

Expand Down Expand Up @@ -134,6 +134,43 @@ client = OpenAI(
)
```

#### AWS Bedrock

Requires `botocore` (`pip install 'openai[bedrock]'`). Credentials are resolved from the [standard AWS credential chain](https://docs.aws.amazon.com/sdkref/latest/guide/standardized-credentials.html).

```python
from openai import OpenAI
from openai.auth import aws_bedrock_token_provider

client = OpenAI(
base_url="https://bedrock-mantle.us-east-1.api.aws/v1", # region must match the token provider
api_key=aws_bedrock_token_provider(
region="us-east-1",
profile="my-profile", # optional — defaults to the standard AWS credential chain
),
)

# List models supported by the OpenAI-compatible endpoint
for model in client.models.list().data:
print(model.id)
```

For `AsyncOpenAI`, use `async_aws_bedrock_token_provider`:

```python
from openai import AsyncOpenAI
from openai.auth import async_aws_bedrock_token_provider

client = AsyncOpenAI(
base_url="https://bedrock-mantle.us-east-1.api.aws/v1", # region must match the token provider
api_key=async_aws_bedrock_token_provider(
region="us-east-1",
),
)
```

> **Note:** The OpenAI SDK works only with Bedrock models that have the [OpenAI-compatible API](https://docs.aws.amazon.com/bedrock/latest/userguide/bedrock-mantle.html) enabled. Use `client.models.list()` to see which models are available on your endpoint.

#### Custom subject token provider

```python
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"]
realtime = ["websockets >= 13, < 16"]
datalib = ["numpy >= 1", "pandas >= 1.2.3", "pandas-stubs >= 1.1.0.11"]
voice_helpers = ["sounddevice>=0.5.1", "numpy>=2.0.2"]
bedrock = ["botocore>=1.29.0"]

[tool.rye]
managed = true
Expand Down
4 changes: 4 additions & 0 deletions src/openai/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
SubjectTokenProvider as SubjectTokenProvider,
WorkloadIdentityAuth as WorkloadIdentityAuth,
gcp_id_token_provider as gcp_id_token_provider,
aws_bedrock_token_provider as aws_bedrock_token_provider,
async_aws_bedrock_token_provider as async_aws_bedrock_token_provider,
k8s_service_account_token_provider as k8s_service_account_token_provider,
azure_managed_identity_token_provider as azure_managed_identity_token_provider,
)
Expand All @@ -16,4 +18,6 @@
"k8s_service_account_token_provider",
"azure_managed_identity_token_provider",
"gcp_id_token_provider",
"aws_bedrock_token_provider",
"async_aws_bedrock_token_provider",
]
117 changes: 116 additions & 1 deletion src/openai/auth/_workload.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import os
import time
import base64
import threading
from typing import Any, Callable, TypedDict, cast
from typing import Any, Callable, Awaitable, TypedDict, cast
from pathlib import Path
from typing_extensions import Literal, NotRequired

Expand Down Expand Up @@ -173,6 +175,119 @@ def get_token() -> str:
return {"token_type": "id", "get_token": get_token}


def _make_bedrock_token_generator(
*,
region: str | None = None,
profile: str | None = None,
) -> Callable[[], str]:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Add async-compatible Bedrock token provider

aws_bedrock_token_provider() currently returns a synchronous Callable[[], str], but AsyncOpenAI unconditionally awaits api_key providers in AsyncOpenAI._refresh_api_key. If an async user passes this new helper (the same way as the sync example), requests fail at runtime with a TypeError because str is not awaitable. This makes the new Bedrock auth path unusable for async clients unless users write their own wrapper.

Useful? React with 👍 / 👎.


_session: list[Any] = [None]

def get_token() -> str:
try:
import botocore.session
from botocore.auth import SigV4QueryAuth
from botocore.awsrequest import AWSRequest
except ImportError as e:
raise ImportError(
"botocore is required for AWS Bedrock token generation. Install it with: pip install 'openai[bedrock]'"
) from e

try:
resolved_region = region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION")
if not resolved_region:
raise SubjectTokenProviderError(
"AWS region must be provided via the 'region' parameter, "
"or the AWS_REGION / AWS_DEFAULT_REGION environment variable."
)

if _session[0] is None:
_session[0] = botocore.session.Session(profile=profile)

credentials = _session[0].get_credentials()
if credentials is None:
raise SubjectTokenProviderError("No AWS credentials found. Ensure your AWS credentials are configured.")
frozen_credentials = credentials.get_frozen_credentials()

request = AWSRequest(
method="POST",
url="https://bedrock.amazonaws.com/",
headers={"host": "bedrock.amazonaws.com"},
params={"Action": "CallWithBearerToken"},
)

signer = SigV4QueryAuth(frozen_credentials, "bedrock", resolved_region)
signer.add_auth(request)

signed_url = request.url
# Strip the https:// prefix before encoding
url_without_scheme = signed_url[len("https://") :]
encoded_token = base64.b64encode(f"{url_without_scheme}&Version=1".encode()).decode()

return f"bedrock-api-key-{encoded_token}"
except (ImportError, SubjectTokenProviderError):
raise
except Exception as e:
raise SubjectTokenProviderError(f"Failed to generate AWS Bedrock token: {e}") from e

return get_token


def aws_bedrock_token_provider(
    *,
    region: str | None = None,
    profile: str | None = None,
) -> Callable[[], str]:
    """
    Create a synchronous AWS Bedrock token provider for use with ``OpenAI``.

    The returned callable produces a bearer token derived from a SigV4 presigned
    URL. Hand it to the client's ``api_key`` argument when the client targets a
    Bedrock endpoint. Credentials come from the standard AWS credential chain:
    https://docs.aws.amazon.com/sdkref/latest/guide/standardized-credentials.html

    The underlying botocore session is created once and reused, whereas a new
    token is generated on every invocation so that it is signed with the
    credentials current at that moment (relevant for short-lived STS or
    assumed-role sessions).

    ``AsyncOpenAI`` users should call :func:`async_aws_bedrock_token_provider`.

    Args:
        region: AWS region; it must agree with the region in the client's ``base_url``.
            Falls back to the ``AWS_REGION`` or ``AWS_DEFAULT_REGION`` environment variable.
        profile: Name of the AWS profile to use. When omitted, botocore resolves
            credentials from the standard chain.
    """
    get_token = _make_bedrock_token_generator(region=region, profile=profile)
    return get_token


def async_aws_bedrock_token_provider(
    *,
    region: str | None = None,
    profile: str | None = None,
) -> Callable[[], Awaitable[str]]:
    """
    Create an asynchronous AWS Bedrock token provider for use with ``AsyncOpenAI``.

    The returned coroutine function produces a bearer token derived from a SigV4
    presigned URL. Hand it to the client's ``api_key`` argument when the client
    targets a Bedrock endpoint. Credentials come from the standard AWS credential chain:
    https://docs.aws.amazon.com/sdkref/latest/guide/standardized-credentials.html

    Token generation itself is synchronous; each await dispatches it through
    ``to_thread``.

    Synchronous ``OpenAI`` users should call :func:`aws_bedrock_token_provider`.

    Args:
        region: AWS region; it must agree with the region in the client's ``base_url``.
            Falls back to the ``AWS_REGION`` or ``AWS_DEFAULT_REGION`` environment variable.
        profile: Name of the AWS profile to use. When omitted, botocore resolves
            credentials from the standard chain.
    """
    generate_token = _make_bedrock_token_generator(region=region, profile=profile)

    async def get_token() -> str:
        token = await to_thread(generate_token)
        return token

    return get_token


class WorkloadIdentityAuth:
def __init__(
self,
Expand Down
73 changes: 73 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from __future__ import annotations

import json
import base64
from typing import cast
from pathlib import Path
from unittest.mock import MagicMock, patch

import httpx
import respx
Expand All @@ -9,8 +13,11 @@
from inline_snapshot import snapshot

from openai import OpenAI, OAuthError
from openai._exceptions import SubjectTokenProviderError
from openai.auth._workload import (
gcp_id_token_provider,
aws_bedrock_token_provider,
async_aws_bedrock_token_provider,
k8s_service_account_token_provider,
azure_managed_identity_token_provider,
)
Expand Down Expand Up @@ -188,3 +195,69 @@ def test_gcp_id_token_provider() -> None:

assert provider["token_type"] == "id"
assert provider["get_token"]() == "gcp-token"


def _mock_botocore() -> MagicMock:
"""Create a minimal mock botocore that stubs SigV4 signing."""
mock = MagicMock()
mock.session.Session.return_value.get_credentials.return_value.get_frozen_credentials.return_value = MagicMock()

def _fake_add_auth(request: MagicMock) -> None:
request.url += "&X-Amz-Credential=FAKE&X-Amz-Signature=FAKE"

mock.auth.SigV4QueryAuth.return_value.add_auth = _fake_add_auth
mock.awsrequest.AWSRequest.return_value = MagicMock(url="https://bedrock.amazonaws.com/?Action=CallWithBearerToken")

return mock


def _patch_botocore(mock: MagicMock): # type: ignore[type-arg]
return patch.dict(
"sys.modules",
{
"botocore": mock,
"botocore.session": mock.session,
"botocore.auth": mock.auth,
"botocore.awsrequest": mock.awsrequest,
},
)


def test_aws_bedrock_token_provider() -> None:
    """The sync provider returns a prefixed, base64-encoded signed URL."""
    prefix = "bedrock-api-key-"
    expected_fragments = (
        "bedrock.amazonaws.com",
        "X-Amz-Signature=",
        "Action=CallWithBearerToken",
        "&Version=1",
    )
    fake_botocore = _mock_botocore()

    with _patch_botocore(fake_botocore):
        get_token = aws_bedrock_token_provider(region="us-east-1")
        token = get_token()
        assert token.startswith(prefix)

        # Everything after the prefix is the base64-encoded, scheme-less signed URL.
        payload = base64.b64decode(token[len(prefix) :]).decode()
        for fragment in expected_fragments:
            assert fragment in payload


def test_aws_bedrock_token_provider_no_credentials() -> None:
    """A session that yields no credentials surfaces as SubjectTokenProviderError."""
    fake_botocore = MagicMock()
    fake_botocore.session.Session.return_value.get_credentials.return_value = None

    get_token = aws_bedrock_token_provider(region="us-east-1")

    with _patch_botocore(fake_botocore), pytest.raises(SubjectTokenProviderError, match="No AWS credentials found"):
        get_token()


def test_aws_bedrock_token_provider_no_botocore() -> None:
    """Without botocore importable, the provider raises an ImportError naming the extra."""
    # A ``None`` entry in ``sys.modules`` makes the corresponding import fail.
    blocked_modules = dict.fromkeys(
        ("botocore", "botocore.session", "botocore.auth", "botocore.awsrequest"),
        None,
    )
    get_token = aws_bedrock_token_provider(region="us-east-1")

    with patch.dict("sys.modules", blocked_modules), pytest.raises(
        ImportError, match="botocore is required.*openai\\[bedrock\\]"
    ):
        get_token()


@pytest.mark.asyncio
async def test_async_aws_bedrock_token_provider() -> None:
    """The async provider yields a token carrying the Bedrock API-key prefix."""
    fake_botocore = _mock_botocore()

    with _patch_botocore(fake_botocore):
        get_token = async_aws_bedrock_token_provider(region="us-east-1")
        token = await get_token()
        assert token.startswith("bedrock-api-key-")