Source code for redisvl.utils.rerank.hf_cross_encoder

from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic.v1 import PrivateAttr

from redisvl.utils.rerank.base import BaseReranker


class HFCrossEncoderReranker(BaseReranker):
    """
    Rerank documents against a query with a Hugging Face cross-encoder model.

    The model is loaded through the ``CrossEncoder`` class of the
    ``sentence_transformers`` library, which must be installed.

    .. code-block:: python

        from redisvl.utils.rerank import HFCrossEncoderReranker

        # set up the HFCrossEncoderReranker with a specific model
        reranker = HFCrossEncoderReranker(
            model="cross-encoder/ms-marco-MiniLM-L-6-v2", limit=3
        )

        # rerank raw search results based on user input/query
        results = reranker.rank(
            query="your input query text here",
            docs=[
                {"content": "document 1"},
                {"content": "document 2"},
                {"content": "document 3"},
            ],
        )
    """

    # Holds the loaded ``CrossEncoder`` instance; assigned in
    # ``_initialize_client``.
    _client: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        limit: int = 3,
        return_score: bool = True,
        **kwargs,
    ) -> None:
        """
        Initialize the HFCrossEncoderReranker with a model and ranking criteria.

        Parameters:
            model (str): The name or path of the cross-encoder model to use
                for reranking. Defaults to
                'cross-encoder/ms-marco-MiniLM-L-6-v2'.
            limit (int): The maximum number of results to return after
                reranking.
            return_score (bool): Whether to return scores alongside the
                reranked results.
            **kwargs: Extra keyword arguments forwarded to ``CrossEncoder``.
                ``model_name`` is accepted as an alias for ``model`` and is
                not forwarded.
        """
        # Always remove the alias from kwargs: ``model`` has a truthy default,
        # so a short-circuiting ``model or kwargs.pop(...)`` would leave
        # ``model_name`` behind and forward it to ``CrossEncoder`` in addition
        # to the positional model argument.
        model_name = kwargs.pop("model_name", None)
        model = model_name or model

        super().__init__(
            model=model, rank_by=None, limit=limit, return_score=return_score
        )
        self._initialize_client(**kwargs)

    def _initialize_client(self, **kwargs):
        """
        Set up the Hugging Face cross-encoder client using optional kwargs.

        Raises:
            ImportError: If ``sentence_transformers`` cannot be imported.
        """
        # Imported lazily so the dependency is only needed when this
        # reranker is actually instantiated.
        try:
            from sentence_transformers import CrossEncoder
        except ImportError as exc:
            raise ImportError(
                "HFCrossEncoder reranker requires the sentence-transformers "
                "library. Please install with "
                "`pip install sentence-transformers`"
            ) from exc

        self._client = CrossEncoder(self.model, **kwargs)

    def rank(
        self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
    ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
        """
        Rerank documents for the given query using the loaded cross-encoder.

        Every (query, document text) pair is scored by the model; documents
        are returned in descending score order, truncated to ``limit``.

        Parameters:
            query (str): The user's search query.
            docs (Union[List[Dict[str, Any]], List[str]]): The documents to
                rank. If every item is a dict, each dict's ``"content"`` value
                is scored and dicts without a ``"content"`` key are dropped.
                Otherwise every item is converted with ``str`` and wrapped as
                ``{"content": item}``.
            **kwargs: ``limit`` and ``return_score`` override the instance
                settings for this call.

        Returns:
            Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
                The reranked documents, plus their scores in the same order
                when ``return_score`` is true.

        Raises:
            ValueError: If ``query`` is empty.
            TypeError: If ``query`` is not a string or ``docs`` is not a list.
        """
        limit = kwargs.get("limit", self.limit)
        return_score = kwargs.get("return_score", self.return_score)

        if not query:
            raise ValueError("query cannot be empty")
        if not isinstance(query, str):
            raise TypeError("query must be a string")
        if not isinstance(docs, list):
            raise TypeError("docs must be a list")

        if all(isinstance(doc, dict) for doc in docs):
            doc_subset = [doc for doc in docs if "content" in doc]
        else:
            doc_subset = [{"content": doc} for doc in docs]

        # Covers both an empty ``docs`` list and dict inputs where no item
        # has a "content" key; avoids sending an empty batch to the model.
        if not doc_subset:
            return ([], []) if return_score else []

        texts = [str(doc["content"]) for doc in doc_subset]
        raw_scores = self._client.predict([(query, text) for text in texts])

        # Sort documents and scores together so the returned scores line up
        # with the returned documents.
        ranked = sorted(
            zip(doc_subset, (float(score) for score in raw_scores)),
            key=lambda pair: pair[1],
            reverse=True,
        )[:limit]

        reranked_docs = [doc for doc, _ in ranked]
        if return_score:
            return reranked_docs, [score for _, score in ranked]  # type: ignore
        return reranked_docs

    async def arank(
        self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
    ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
        """
        Coroutine counterpart of :meth:`rank`.

        This delegates directly to the synchronous :meth:`rank`; the model
        call is not awaited or offloaded, so it runs inline in the calling
        task.

        Parameters:
            query (str): The user's search query.
            docs (Union[List[Dict[str, Any]], List[str]]): The list of
                documents to be ranked, either as dictionaries or strings.
            **kwargs: Passed through to :meth:`rank`.

        Returns:
            Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
                The reranked list of documents and optionally associated
                scores.
        """
        return self.rank(query, docs, **kwargs)