Skip to content

Commit 5f77e85

Browse files
jsklan and claude committed
fix: handle embedding_types dict response in AWS client embed methods
When embedding_types is specified, the Cohere API returns embeddings as a dict keyed by type (e.g. {"float": [[...]], "int8": [[...]]}) instead of a flat list. Both _bedrock_embed and _sagemaker_embed now detect the dict format and return it directly instead of wrapping it in Embeddings, which would silently produce wrong results for len() and iteration.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
1 parent bae43fa commit 5f77e85

2 files changed

Lines changed: 53 additions & 3 deletions

File tree

src/cohere/manually_maintained/cohere_aws/client.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -564,7 +564,7 @@ def embed(
564564
model_id: Optional[str] = None,
565565
output_dimension: Optional[int] = None,
566566
embedding_types: Optional[List[str]] = None,
567-
) -> Embeddings:
567+
) -> Union[Embeddings, Dict[str, List]]:
568568
json_params = {
569569
'texts': texts,
570570
'truncate': truncate,
@@ -607,7 +607,10 @@ def _sagemaker_embed(self, json_params: Dict[str, Any], variant: str):
607607
# ValidationError, e.g. when variant is bad
608608
raise CohereError(str(e))
609609

610-
return Embeddings(response['embeddings'])
610+
embeddings = response['embeddings']
611+
if isinstance(embeddings, dict):
612+
return embeddings
613+
return Embeddings(embeddings)
611614

612615
def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
613616
if not model_id:
@@ -628,7 +631,10 @@ def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
628631
# ValidationError, e.g. when variant is bad
629632
raise CohereError(str(e))
630633

631-
return Embeddings(response['embeddings'])
634+
embeddings = response['embeddings']
635+
if isinstance(embeddings, dict):
636+
return embeddings
637+
return Embeddings(embeddings)
632638

633639

634640
def rerank(self,

tests/test_aws_client_unit.py

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -206,3 +206,47 @@ def test_embed_omits_none_params(self) -> None:
206206

207207
self.assertNotIn("output_dimension", captured_body)
208208
self.assertNotIn("embedding_types", captured_body)
209+
210+
def test_embed_with_embedding_types_returns_dict(self) -> None:
211+
"""When embedding_types is specified, the API returns embeddings as a dict.
212+
The client should return that dict rather than wrapping it in Embeddings."""
213+
mock_boto3 = MagicMock()
214+
mock_botocore = MagicMock()
215+
216+
by_type_embeddings = {"float": [[0.1, 0.2]], "int8": [[1, 2]]}
217+
218+
def fake_invoke_model(**kwargs): # type: ignore
219+
mock_body = MagicMock()
220+
mock_body.read.return_value = json.dumps({
221+
"embeddings": by_type_embeddings,
222+
"response_type": "embeddings_by_type",
223+
}).encode()
224+
return {"body": mock_body}
225+
226+
mock_bedrock_client = MagicMock()
227+
mock_bedrock_client.invoke_model.side_effect = fake_invoke_model
228+
229+
def fake_boto3_client(service_name, **kwargs): # type: ignore
230+
if service_name == "bedrock-runtime":
231+
return mock_bedrock_client
232+
return MagicMock()
233+
234+
mock_boto3.client.side_effect = fake_boto3_client
235+
236+
with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \
237+
patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=mock_botocore), \
238+
patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \
239+
patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):
240+
241+
from cohere.manually_maintained.cohere_aws.client import Client
242+
243+
client = Client(aws_region="us-east-1", mode=Mode.BEDROCK)
244+
result = client.embed(
245+
texts=["hello world"],
246+
input_type="search_document",
247+
model_id="cohere.embed-english-v3",
248+
embedding_types=["float", "int8"],
249+
)
250+
251+
self.assertIsInstance(result, dict)
252+
self.assertEqual(result, by_type_embeddings)

0 commit comments

Comments
 (0)