Source code for redisvl.query.aggregate

from typing import Any, Dict, List, Optional, Set, Tuple, Union

from redis.commands.search.aggregation import AggregateRequest, Desc

from redisvl.query.filter import FilterExpression
from redisvl.redis.utils import array_to_buffer
from redisvl.utils.token_escaper import TokenEscaper


class AggregationQuery(AggregateRequest):
    """
    Base class for the aggregation queries defined in this module.

    It is a thin subclass of redis-py's ``AggregateRequest`` that adds no
    behavior of its own.
    """

    def __init__(self, query_string: str) -> None:
        """
        Initialize the aggregation request.

        Args:
            query_string (str): The query string, forwarded unchanged to
                ``AggregateRequest.__init__``.
        """
        super().__init__(query_string)


class HybridQuery(AggregationQuery):
    """
    HybridQuery combines text and vector search in Redis.

    It allows you to perform a hybrid search using both text and vector
    similarity. It scores documents based on a weighted combination of text
    and vector similarity.

    .. code-block:: python

        from redisvl.query import HybridQuery
        from redisvl.index import SearchIndex

        index = SearchIndex.from_yaml("path/to/index.yaml")

        query = HybridQuery(
            text="example text",
            text_field_name="text_field",
            vector=[0.1, 0.2, 0.3],
            vector_field_name="vector_field",
            text_scorer="BM25STD",
            filter_expression=None,
            alpha=0.7,
            dtype="float32",
            num_results=10,
            return_fields=["field1", "field2"],
            stopwords="english",
            dialect=2,
        )

        results = index.query(query)
    """

    # Alias given to the KNN distance in the query; referenced as
    # ``@vector_distance`` when computing the similarity score.
    DISTANCE_ID: str = "vector_distance"
    # Name of the query parameter that carries the serialized vector.
    VECTOR_PARAM: str = "vector"

    def __init__(
        self,
        text: str,
        text_field_name: str,
        vector: Union[bytes, List[float]],
        vector_field_name: str,
        text_scorer: str = "BM25STD",
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        alpha: float = 0.7,
        dtype: str = "float32",
        num_results: int = 10,
        return_fields: Optional[List[str]] = None,
        stopwords: Optional[Union[str, Set[str]]] = "english",
        dialect: int = 2,
    ):
        """
        Instantiates a HybridQuery object.

        Args:
            text (str): The text to search for.
            text_field_name (str): The text field name to search in.
            vector (Union[bytes, List[float]]): The vector to perform vector
                similarity search.
            vector_field_name (str): The vector field name to search in.
            text_scorer (str, optional): The text scorer to use. Options are
                {TFIDF, TFIDF.DOCNORM, BM25, DISMAX, DOCSCORE, BM25STD}.
                Defaults to "BM25STD".
            filter_expression (Optional[Union[str, FilterExpression]], optional):
                The filter expression to use, either as a ``FilterExpression``
                or as a raw query-string fragment. Defaults to None.
            alpha (float, optional): The weight of the vector similarity.
                Documents will be scored as:
                hybrid_score = (alpha) * vector_score + (1-alpha) * text_score.
                Defaults to 0.7.
            dtype (str, optional): The data type of the vector. Defaults to
                "float32".
            num_results (int, optional): The number of results to return.
                Defaults to 10.
            return_fields (Optional[List[str]], optional): The fields to
                return. Defaults to None.
            stopwords (Optional[Union[str, Set[str]]], optional): The stopwords
                to remove from the provided text prior to search. If a string
                such as "english" or "german" is provided then a default set of
                stopwords for that language will be used. If a list, set, or
                tuple of strings is provided then those will be used as
                stopwords. Defaults to "english". If set to None then no
                stopwords will be removed.
            dialect (int, optional): The Redis dialect version. Defaults to 2.

        Raises:
            ValueError: If the text string is empty, or if the text string
                becomes empty after stopwords are removed.
            TypeError: If the stopwords are not a set, list, or tuple of
                strings.
        """
        if not text.strip():
            raise ValueError("text string cannot be empty")

        self._text = text
        self._text_field = text_field_name
        self._vector = vector
        self._vector_field = vector_field_name
        self._filter_expression = filter_expression
        self._alpha = alpha
        self._dtype = dtype
        self._num_results = num_results
        # Stopwords must be set before the query string is built, because
        # tokenization filters against them.
        self._set_stopwords(stopwords)

        query_string = self._build_query_string()
        super().__init__(query_string)

        self.scorer(text_scorer)  # type: ignore[attr-defined]
        self.add_scores()  # type: ignore[attr-defined]
        # (2 - d) / 2 maps a distance in [0, 2] onto a similarity in [0, 1].
        # NOTE(review): this assumes a cosine distance metric on the vector
        # field -- confirm against the index schema.
        self.apply(
            vector_similarity=f"(2 - @{self.DISTANCE_ID})/2", text_score="@__score"
        )
        self.apply(hybrid_score=f"{1-alpha}*@text_score + {alpha}*@vector_similarity")
        self.sort_by(Desc("@hybrid_score"), max=num_results)
        self.dialect(dialect)  # type: ignore[attr-defined]

        if return_fields:
            self.load(*return_fields)

    @property
    def params(self) -> Dict[str, Any]:
        """Return the parameters for the aggregation.

        A vector supplied as ``bytes`` is passed through unchanged; any other
        vector is serialized with ``array_to_buffer`` using the configured
        dtype.

        Returns:
            Dict[str, Any]: The parameters for the aggregation.
        """
        if isinstance(self._vector, bytes):
            vector = self._vector
        else:
            vector = array_to_buffer(self._vector, dtype=self._dtype)

        return {self.VECTOR_PARAM: vector}

    @property
    def stopwords(self) -> Set[str]:
        """Return the stopwords used in the query.

        Returns:
            Set[str]: A copy of the stopwords used in the query (empty set if
                none are configured).
        """
        return self._stopwords.copy() if self._stopwords else set()

    def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
        """Set the stopwords to use in the query.

        Args:
            stopwords (Optional[Union[str, Set[str]]]): The stopwords to use.
                If a string such as "english" or "german" is provided then a
                default set of stopwords for that language will be loaded from
                nltk. If a list, set, or tuple of strings is provided then
                those will be used as stopwords. Defaults to "english". If
                None (or any other falsy value, e.g. an empty collection) is
                given then no stopwords will be removed.

        Raises:
            ValueError: If a language name is given but nltk is not installed
                or the stopwords for that language cannot be loaded.
            TypeError: If the stopwords are not a set, list, or tuple of
                strings.
        """
        if not stopwords:
            self._stopwords: Set[str] = set()
        elif isinstance(stopwords, str):
            # Lazy import because nltk is an optional dependency
            try:
                import nltk
                from nltk.corpus import stopwords as nltk_stopwords
            except ImportError as e:
                raise ValueError(
                    f"Loading stopwords for {stopwords} failed: nltk is not installed."
                ) from e
            try:
                nltk.download("stopwords", quiet=True)
                self._stopwords = set(nltk_stopwords.words(stopwords))
            except Exception as e:
                raise ValueError(
                    f"Error trying to load {stopwords} from nltk. {e}"
                ) from e
        elif isinstance(stopwords, (set, list, tuple)) and all(
            isinstance(word, str) for word in stopwords
        ):
            self._stopwords = set(stopwords)
        else:
            raise TypeError("stopwords must be a set, list, or tuple of strings")

    def _tokenize_and_escape_query(self, user_query: str) -> str:
        """Convert a raw user query to a redis full text query joined by ORs

        Each whitespace-separated token is stripped, has surrounding commas
        and curly double quotes removed, is lower-cased and then escaped.
        Empty tokens and stopwords are dropped before joining with ``" | "``.

        Args:
            user_query (str): The user query to tokenize and escape.

        Returns:
            str: The tokenized and escaped query string.

        Raises:
            ValueError: If the text string becomes empty after stopwords are
                removed.
        """
        escaper = TokenEscaper()

        tokens = [
            escaper.escape(
                token.strip().strip(",").replace("“", "").replace("”", "").lower()
            )
            for token in user_query.split()
        ]
        # NOTE(review): stopwords are compared against the *escaped* token, so
        # a stopword containing characters the escaper rewrites may not be
        # filtered -- confirm whether that is intended.
        tokenized = " | ".join(
            token for token in tokens if token and token not in self._stopwords
        )

        if not tokenized:
            raise ValueError("text string cannot be empty after removing stopwords")
        return tokenized

    def _build_query_string(self) -> str:
        """Build the full query string for text search with optional filtering.

        The filter may be a ``FilterExpression`` (rendered with ``str()``) or
        a raw string fragment (used as given, with surrounding whitespace
        removed). An empty filter or the wildcard ``"*"`` adds no clause.

        Returns:
            str: The query string combining the text clause, the optional
                filter clause and the KNN clause.

        Raises:
            ValueError: If the text string becomes empty after stopwords are
                removed.
        """
        raw_filter = self._filter_expression
        if isinstance(raw_filter, FilterExpression):
            filter_expression = str(raw_filter)
        elif isinstance(raw_filter, str):
            # The constructor accepts plain strings; honor them rather than
            # silently dropping the filter.
            filter_expression = raw_filter.strip()
        else:
            filter_expression = ""

        # base KNN query
        knn_query = (
            f"KNN {self._num_results} @{self._vector_field} "
            f"${self.VECTOR_PARAM} AS {self.DISTANCE_ID}"
        )

        # NOTE(review): the leading `~` presumably marks the text clause as
        # optional, so vector-only matches are still returned -- confirm
        # against the Redis query syntax reference.
        text = f"(~@{self._text_field}:({self._tokenize_and_escape_query(self._text)})"

        if filter_expression and filter_expression != "*":
            text += f" AND {filter_expression}"

        return f"{text})=>[{knn_query}]"