-
Notifications
You must be signed in to change notification settings - Fork 971
Expand file tree
/
Copy pathembedding.py
More file actions
81 lines (70 loc) · 2.76 KB
/
embedding.py
File metadata and controls
81 lines (70 loc) · 2.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
Adapted from SakanaAI/ShinkaEvolve (Apache-2.0 License)
Original source: https://github.com/SakanaAI/ShinkaEvolve/blob/main/shinka/llm/embedding.py
"""
import os
import openai
from typing import Union, List
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Model names that are routed to the Azure OpenAI client. The "azure-" prefix
# only selects the backend; it is stripped before the name is sent to the API.
AZURE_EMBEDDING_MODELS = [
    f"azure-text-embedding-3-{size}" for size in ("small", "large")
]
class EmbeddingClient:
def __init__(self, model_name: str = "text-embedding-3-small", base_url: str | None = None):
"""
Initialize the EmbeddingClient.
Args:
model_name: The embedding model name to use.
base_url: Optional base URL for the embedding API endpoint.
"""
self.client, self.model = self._get_client_model(model_name, base_url)
def _get_client_model(
self, model_name: str, base_url: str | None = None
) -> tuple[openai.OpenAI, str]:
if model_name in AZURE_EMBEDDING_MODELS:
# get rid of the azure- prefix
model_to_use = model_name.split("azure-")[-1]
client = openai.AzureOpenAI(
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_API_VERSION"),
azure_endpoint=os.environ["AZURE_API_ENDPOINT"],
)
else:
# Use OPENAI_EMBEDDING_API_KEY if set, otherwise fall back to OPENAI_API_KEY
# This allows users to use OpenRouter for LLMs while using OpenAI for embeddings
embedding_api_key = os.getenv("OPENAI_EMBEDDING_API_KEY") or os.getenv("OPENAI_API_KEY")
client = openai.OpenAI(api_key=embedding_api_key, base_url=base_url)
model_to_use = model_name
return client, model_to_use
def get_embedding(self, code: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
"""
Computes the text embedding for a code string.
Args:
code (str, list[str]): The code as a string or list
of strings.
Returns:
list: Embedding vector for the code or None if an error
occurs.
"""
if isinstance(code, str):
code = [code]
single_code = True
else:
single_code = False
try:
response = self.client.embeddings.create(
model=self.model, input=code, encoding_format="float"
)
# Extract embedding from response
if single_code:
return response.data[0].embedding
else:
return [d.embedding for d in response.data]
except Exception as e:
logger.info(f"Error getting embedding: {e}")
if single_code:
return [], 0.0
else:
return [[]], 0.0