Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions fastembed/text/onnx_embedding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import json
from pathlib import Path
from typing import Any, Iterable, Optional, Sequence, Type, Union

import numpy as np

from fastembed.common.types import NumpyArray, OnnxProvider
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.utils import define_cache_dir, normalize
Expand Down Expand Up @@ -180,6 +184,24 @@
sources=ModelSource(hf="jinaai/jina-clip-v1"),
model_file="onnx/text_model.onnx",
),
DenseModelDescription(
model="jinaai/jina-embeddings-v3",
dim=1024,
description=(
"Text embeddings, Unimodal (text), Multilingual (89+ languages), 8192 input tokens truncation, "
"Task-specific LoRA adapters (retrieval, classification, text-matching, clustering), "
"Matryoshka dimensions: 32-1024, 2024 year."
),
license="apache-2.0",
size_in_GB=2.29,
sources=ModelSource(hf="jinaai/jina-embeddings-v3"),
model_file="onnx/model.onnx",
additional_files=["onnx/model.onnx_data"],
tasks={
"query_task": "retrieval.query",
"passage_task": "retrieval.passage",
},
),
]


Expand Down Expand Up @@ -255,6 +277,14 @@ def __init__(
specific_model_path=self._specific_model_path,
)

# Load LoRA adaptations for models that support task-specific embeddings (e.g., Jina v3)
self.lora_adaptations: Optional[list[str]] = None
config_path = Path(self._model_dir) / "config.json"
if config_path.exists():
with open(config_path, "r") as f:
config = json.load(f)
self.lora_adaptations = config.get("lora_adaptations")

Comment on lines +288 to +332
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate lora_adaptations from config.json and fail fast for Jina v3 if missing.

Currently, non-list/empty values silently pass, which can lead to runtime shape/key errors later. Add minimal validation and a clear error for this model.

Apply this diff:

         self.lora_adaptations: Optional[list[str]] = None
         config_path = Path(self._model_dir) / "config.json"
         if config_path.exists():
             with open(config_path, "r") as f:
                 config = json.load(f)
-                self.lora_adaptations = config.get("lora_adaptations")
+                la = config.get("lora_adaptations")
+                if isinstance(la, list) and all(isinstance(x, str) for x in la):
+                    self.lora_adaptations = la
+                else:
+                    self.lora_adaptations = None
+
+        # Fail fast when Jina v3 is selected but LoRA metadata is unavailable
+        if (
+            self.model_description.model.lower() == "jinaai/jina-embeddings-v3"
+            and not self.lora_adaptations
+        ):
+            raise ValueError(
+                "Missing or invalid 'lora_adaptations' in config.json for jinaai/jina-embeddings-v3."
+            )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Load LoRA adaptations for models that support task-specific embeddings (e.g., Jina v3)
self.lora_adaptations: Optional[list[str]] = None
config_path = Path(self._model_dir) / "config.json"
if config_path.exists():
with open(config_path, "r") as f:
config = json.load(f)
self.lora_adaptations = config.get("lora_adaptations")
# Load LoRA adaptations for models that support task-specific embeddings (e.g., Jina v3)
self.lora_adaptations: Optional[list[str]] = None
config_path = Path(self._model_dir) / "config.json"
if config_path.exists():
with open(config_path, "r") as f:
config = json.load(f)
la = config.get("lora_adaptations")
if isinstance(la, list) and all(isinstance(x, str) for x in la):
self.lora_adaptations = la
else:
self.lora_adaptations = None
# Fail fast when Jina v3 is selected but LoRA metadata is unavailable
if (
self.model_description.model.lower() == "jinaai/jina-embeddings-v3"
and not self.lora_adaptations
):
raise ValueError(
"Missing or invalid 'lora_adaptations' in config.json for jinaai/jina-embeddings-v3."
)
🤖 Prompt for AI Agents
In fastembed/text/onnx_embedding.py around lines 280 to 287, the code reads
lora_adaptations from config.json but doesn't validate it; add validation to
ensure config.get("lora_adaptations") is a non-empty list of strings and, if
not, raise a clear ValueError (fail fast) when this model requires task-specific
LoRA (e.g., Jina v3); specifically: after loading config, verify the key exists,
is a list, and each item is a string; set self.lora_adaptations to the validated
list, and if validation fails for a model that requires it, raise a descriptive
error explaining that lora_adaptations in config.json must be a non-empty list
of strings.

if not self.lazy_load:
self.load_onnx_model()

Expand Down Expand Up @@ -303,7 +333,20 @@ def _preprocess_onnx_input(
) -> dict[str, NumpyArray]:
"""
Preprocess the onnx input.
Adds task_id for models with LoRA adapters (e.g., Jina v3).
"""
# Handle task-specific embeddings for models with LoRA adapters
if self.lora_adaptations:
task_type = kwargs.get("task_type")

# If no task specified, use default (text-matching for general purpose)
if not task_type:
# Default to text-matching if available, otherwise first task
task_type = "text-matching" if "text-matching" in self.lora_adaptations else self.lora_adaptations[0]

if task_type in self.lora_adaptations:
task_id = np.array(self.lora_adaptations.index(task_type), dtype=np.int64)
onnx_input["task_id"] = task_id
return onnx_input

def _post_process_onnx_output(
Expand All @@ -329,6 +372,46 @@ def load_onnx_model(self) -> None:
device_id=self.device_id,
)

def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
    """
    Embed one or more queries, selecting the query task when the model declares one.

    When the model description has a "query_task" entry in its `tasks` mapping,
    that value is passed to `embed` as `task_type`, replacing any `task_type`
    already present in `kwargs`.

    Args:
        query (Union[str, Iterable[str]]): A single query string, or an iterable of queries.
        **kwargs: Additional keyword arguments forwarded to `embed`.

    Returns:
        Iterable[NumpyArray]: The embeddings, yielded as `embed` produces them.
    """
    task_map = self.model_description.tasks
    if task_map and "query_task" in task_map:
        kwargs["task_type"] = task_map["query_task"]

    # A bare string is wrapped in a list so `embed` always receives a collection of texts.
    batch = [query] if isinstance(query, str) else query
    yield from self.embed(batch, **kwargs)

def passage_embed(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
    """
    Embed one or more passages, selecting the passage task when the model declares one.

    When the model description has a "passage_task" entry in its `tasks` mapping,
    that value is passed to `embed` as `task_type`, replacing any `task_type`
    already present in `kwargs`.

    Args:
        texts (Union[str, Iterable[str]]): A single passage string, or an iterable of passages.
        **kwargs: Additional keyword arguments forwarded to `embed`.

    Returns:
        Iterable[NumpyArray]: The embeddings, yielded as `embed` produces them.
    """
    # Normalise the input first: a lone string becomes a one-element list.
    if isinstance(texts, str):
        texts = [texts]

    declared_tasks = self.model_description.tasks
    if declared_tasks and "passage_task" in declared_tasks:
        kwargs["task_type"] = declared_tasks["passage_task"]

    yield from self.embed(texts, **kwargs)


class OnnxTextEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
def init_embedding(
Expand Down
62 changes: 62 additions & 0 deletions tests/test_text_onnx_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
"thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]),
"jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]),
"jinaai/jina-embeddings-v3": np.array([0.07257809, -0.08073004, 0.09241360, -0.01755937, 0.06534681]),
}

MULTI_TASK_MODELS = ["jinaai/jina-embeddings-v3"]
Expand Down Expand Up @@ -175,3 +176,64 @@ def test_embedding_size() -> None:

if is_ci:
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize("model_name", MULTI_TASK_MODELS)
def test_multi_task_embedding(model_name: str) -> None:
    """Exercise the query, passage and plain embedding paths of a multi-task model.

    Checks the output shape of `query_embed`, `passage_embed` and `embed`, that the
    query and passage embeddings of the same text differ, and, when a reference
    vector is registered for the model, that the leading values of the first plain
    embedding match it.
    """
    running_in_ci = os.getenv("CI")
    manually_dispatched = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

    # In CI this test only runs for manually dispatched workflows.
    if running_in_ci and not manually_dispatched:
        pytest.skip("Skipping multi-task model tests in CI (large models)")

    model_desc = next(
        (desc for desc in TextEmbedding._list_supported_models() if desc.model == model_name),
        None,
    )
    assert model_desc is not None, f"Model {model_name} not found in supported models"

    expected_shape = (2, model_desc.dim)
    model = TextEmbedding(model_name=model_name)

    # Query path
    query_embeddings = np.stack(
        list(model.query_embed(["What is the capital of France?", "How does photosynthesis work?"])),
        axis=0,
    )
    assert query_embeddings.shape == expected_shape, f"Query embeddings shape mismatch for {model_name}"

    # Passage path
    passage_embeddings = np.stack(
        list(
            model.passage_embed(
                ["Paris is the capital of France.", "Photosynthesis is a process used by plants."]
            )
        ),
        axis=0,
    )
    assert passage_embeddings.shape == expected_shape, f"Passage embeddings shape mismatch for {model_name}"

    # Plain embed path: must work without any task being specified
    embeddings = np.stack(list(model.embed(["hello world", "flag embedding"])), axis=0)
    assert embeddings.shape == expected_shape, f"Regular embeddings shape mismatch for {model_name}"

    # The same text embedded as a query and as a passage is expected to give
    # different vectors, since each path requests a different task.
    test_text = "This is a test sentence"
    query_emb = np.array(list(model.query_embed([test_text])))
    passage_emb = np.array(list(model.passage_embed([test_text])))
    embeddings_match = np.allclose(query_emb, passage_emb, atol=1e-6)
    assert not embeddings_match, f"Query and passage embeddings should differ for {model_name}"

    # Optional reference check against the first plain embedding ("hello world")
    if model_name in CANONICAL_VECTOR_VALUES:
        canonical_vector = CANONICAL_VECTOR_VALUES[model_name]
        leading_values = embeddings[0, : canonical_vector.shape[0]]
        assert np.allclose(
            leading_values, canonical_vector, atol=1e-3
        ), f"Canonical vector mismatch for {model_name}"

    if running_in_ci:
        delete_model_cache(model.model._model_dir)