Skip to content

Commit e542ef0

Browse files
fern-supportclaude
andcommitted
fix: resolve AWS client SigV4 signing, forced SageMaker dep, and missing embed params
- Fix SigV4 host header mismatch: update copied headers dict with correct host after URL rewrite, so AWSRequest signs with the Bedrock/SageMaker host instead of stale api.cohere.com - Add mode parameter to cohere_aws.Client to conditionally initialize boto3 clients (bedrock-runtime/bedrock vs sagemaker-runtime/sagemaker), avoiding forced SageMaker dependency for Bedrock users - Add output_dimension and embedding_types params to embed() for Embed v4 Closes #721 Co-Authored-By: Claude Opus 4.6 <[email protected]>
1 parent f366233 commit e542ef0

2 files changed

Lines changed: 16 additions & 5 deletions

File tree

src/cohere/aws_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ def _event_hook(request: httpx.Request) -> None:
239239
)
240240
request.url = URL(url)
241241
request.headers["host"] = request.url.host
242+
headers["host"] = request.url.host
242243

243244
if endpoint == "rerank":
244245
body["api_version"] = get_api_version(version=api_version)

src/cohere/manually_maintained/cohere_aws/client.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,23 @@ class Client:
2020
def __init__(
2121
self,
2222
aws_region: typing.Optional[str] = None,
23+
mode: Mode = Mode.SAGEMAKER,
2324
):
2425
"""
2526
By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with
2627
`aws configure set region us-west-2` or override it with `region_name` parameter.
2728
"""
28-
self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
29-
self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
29+
self.mode = mode
3030
if os.environ.get('AWS_DEFAULT_REGION') is None:
3131
os.environ['AWS_DEFAULT_REGION'] = aws_region
32-
self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
33-
self.mode = Mode.SAGEMAKER
32+
33+
if self.mode == Mode.SAGEMAKER:
34+
self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
35+
self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
36+
self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
37+
elif self.mode == Mode.BEDROCK:
38+
self._client = lazy_boto3().client("bedrock-runtime", region_name=aws_region)
39+
self._service_client = lazy_boto3().client("bedrock", region_name=aws_region)
3440

3541

3642

@@ -550,11 +556,15 @@ def embed(
550556
variant: Optional[str] = None,
551557
input_type: Optional[str] = None,
552558
model_id: Optional[str] = None,
559+
output_dimension: Optional[int] = None,
560+
embedding_types: Optional[List[str]] = None,
553561
) -> Embeddings:
554562
json_params = {
555563
'texts': texts,
556564
'truncate': truncate,
557-
"input_type": input_type
565+
"input_type": input_type,
566+
"output_dimension": output_dimension,
567+
"embedding_types": embedding_types,
558568
}
559569
for key, value in list(json_params.items()):
560570
if value is None:

0 commit comments

Comments
 (0)