|
10 | 10 | aws_region = os.getenv("AWS_REGION") |
11 | 11 | endpoint_type = os.getenv("ENDPOINT_TYPE") |
12 | 12 |
|
| 13 | + |
| 14 | +def _setup_boto3_env(): |
| 15 | + """Bridge custom test env vars to standard boto3 credential env vars.""" |
| 16 | + if aws_access_key: |
| 17 | + os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key |
| 18 | + if aws_secret_key: |
| 19 | + os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_key |
| 20 | + if aws_session_token: |
| 21 | + os.environ["AWS_SESSION_TOKEN"] = aws_session_token |
| 22 | + |
| 23 | + |
13 | 24 | @unittest.skipIf(None == os.getenv("TEST_AWS"), "tests skipped because TEST_AWS is not set") |
14 | 25 | class TestClient(unittest.TestCase): |
15 | 26 | platform: str = "bedrock" |
@@ -109,3 +120,100 @@ def test_chat_stream(self) -> None: |
109 | 120 | self.assertIsNotNone(event.response.text) |
110 | 121 |
|
111 | 122 | self.assertSetEqual(response_types, {"text-generation", "stream-end"}) |
| 123 | + |
| 124 | + |
| 125 | +@unittest.skipIf(None == os.getenv("TEST_AWS"), "tests skipped because TEST_AWS is not set") |
| 126 | +class TestBedrockClientV2(unittest.TestCase): |
| 127 | + """Integration tests for BedrockClientV2 (httpx-based). |
| 128 | +
|
| 129 | + Fix 1 validation: If these pass, SigV4 signing uses the correct host header, |
| 130 | + since the request would fail with a signature mismatch otherwise. |
| 131 | + """ |
| 132 | + |
| 133 | + client: cohere.ClientV2 = cohere.BedrockClientV2( |
| 134 | + aws_access_key=aws_access_key, |
| 135 | + aws_secret_key=aws_secret_key, |
| 136 | + aws_session_token=aws_session_token, |
| 137 | + aws_region=aws_region, |
| 138 | + ) |
| 139 | + |
| 140 | + def test_embed(self) -> None: |
| 141 | + response = self.client.embed( |
| 142 | + model="cohere.embed-multilingual-v3", |
| 143 | + texts=["I love Cohere!"], |
| 144 | + input_type="search_document", |
| 145 | + embedding_types=["float"], |
| 146 | + ) |
| 147 | + self.assertIsNotNone(response) |
| 148 | + |
| 149 | + def test_embed_with_output_dimension(self) -> None: |
| 150 | + response = self.client.embed( |
| 151 | + model="cohere.embed-english-v3", |
| 152 | + texts=["I love Cohere!"], |
| 153 | + input_type="search_document", |
| 154 | + embedding_types=["float"], |
| 155 | + output_dimension=256, |
| 156 | + ) |
| 157 | + self.assertIsNotNone(response) |
| 158 | + |
| 159 | + |
| 160 | +@unittest.skipIf(None == os.getenv("TEST_AWS"), "tests skipped because TEST_AWS is not set") |
| 161 | +class TestCohereAwsBedrockClient(unittest.TestCase): |
| 162 | + """Integration tests for cohere_aws.Client in Bedrock mode (boto3-based). |
| 163 | +
|
| 164 | + Validates: |
| 165 | + - Fix 2: Client can be initialized with mode=BEDROCK without importing sagemaker |
| 166 | + - Fix 3: embed() accepts output_dimension and embedding_types |
| 167 | + """ |
| 168 | + |
| 169 | + @classmethod |
| 170 | + def setUpClass(cls) -> None: |
| 171 | + _setup_boto3_env() |
| 172 | + from cohere.manually_maintained.cohere_aws.client import Client |
| 173 | + from cohere.manually_maintained.cohere_aws.mode import Mode |
| 174 | + cls.client = Client(aws_region=aws_region, mode=Mode.BEDROCK) |
| 175 | + |
| 176 | + def test_client_is_bedrock_mode(self) -> None: |
| 177 | + from cohere.manually_maintained.cohere_aws.mode import Mode |
| 178 | + self.assertEqual(self.client.mode, Mode.BEDROCK) |
| 179 | + |
| 180 | + def test_embed(self) -> None: |
| 181 | + response = self.client.embed( |
| 182 | + texts=["I love Cohere!"], |
| 183 | + input_type="search_document", |
| 184 | + model_id="cohere.embed-multilingual-v3", |
| 185 | + ) |
| 186 | + self.assertIsNotNone(response) |
| 187 | + self.assertIsNotNone(response.embeddings) |
| 188 | + self.assertGreater(len(response.embeddings), 0) |
| 189 | + |
| 190 | + def test_embed_with_embedding_types(self) -> None: |
| 191 | + response = self.client.embed( |
| 192 | + texts=["I love Cohere!"], |
| 193 | + input_type="search_document", |
| 194 | + model_id="cohere.embed-multilingual-v3", |
| 195 | + embedding_types=["float"], |
| 196 | + ) |
| 197 | + self.assertIsNotNone(response) |
| 198 | + self.assertIsNotNone(response.embeddings) |
| 199 | + |
| 200 | + def test_embed_with_output_dimension(self) -> None: |
| 201 | + response = self.client.embed( |
| 202 | + texts=["I love Cohere!"], |
| 203 | + input_type="search_document", |
| 204 | + model_id="cohere.embed-english-v3", |
| 205 | + output_dimension=256, |
| 206 | + embedding_types=["float"], |
| 207 | + ) |
| 208 | + self.assertIsNotNone(response) |
| 209 | + self.assertIsNotNone(response.embeddings) |
| 210 | + |
| 211 | + def test_embed_without_new_params(self) -> None: |
| 212 | + """Backwards compat: embed() still works without the new v4 params.""" |
| 213 | + response = self.client.embed( |
| 214 | + texts=["I love Cohere!"], |
| 215 | + input_type="search_document", |
| 216 | + model_id="cohere.embed-multilingual-v3", |
| 217 | + ) |
| 218 | + self.assertIsNotNone(response) |
| 219 | + self.assertIsNotNone(response.embeddings) |
0 commit comments