"""redisvl.utils.vectorize.text.custom: a vectorizer that wraps user-provided embedding callables."""

from functools import wraps
from typing import Any, Callable, List, Optional, Union

from pydantic import PrivateAttr

from redisvl.utils.utils import deprecated_argument
from redisvl.utils.vectorize.base import BaseVectorizer


def _check_vector(result: list, method_name: str) -> None:
    """
    Validates the structure of returned embeddings.

    - For methods named "*_many", expects a list of lists of floats.
    - For single methods, expects a list of floats.

    Raises:
        ValueError: If the embeddings do not match the expected structure.
    """
    if method_name.endswith("_many"):
        # embed_many / aembed_many → list of lists
        if not isinstance(result, list) or not result:
            raise ValueError(f"{method_name} must return a non-empty list of lists.")
        if not isinstance(result[0], list) or not result[0]:
            raise ValueError(f"{method_name} must return a list of non-empty lists.")
        if not isinstance(result[0][0], float):
            raise ValueError(f"{method_name} must return a list of lists of floats.")
    else:
        # embed / aembed → a single list of floats
        if not isinstance(result, list) or not result:
            raise ValueError(f"{method_name} must return a non-empty list.")
        if not isinstance(result[0], float):
            raise ValueError(f"{method_name} must return a list of floats.")


def validate_async(method):
    """
    Decorator that lazily validates the output of async methods (aembed, aembed_many).
    On first call, it checks the returned embeddings with _check_vector, then sets a flag
    so subsequent calls skip re-validation.
    """

    async def wrapper(self, *args, **kwargs):
        result = await method(self, *args, **kwargs)
        method_name = method.__name__
        validated_attr = f"_{method_name}_validated"

        try:
            if not getattr(self, validated_attr):
                _check_vector(result, method_name)
                setattr(self, validated_attr, True)
        except Exception as e:
            raise ValueError(f"Invalid embedding method: {e}")

        return result

    return wrapper


class CustomTextVectorizer(BaseVectorizer):
    """The CustomTextVectorizer class wraps user-defined embedding methods to
    create embeddings for text data.

    This vectorizer is designed to accept a provided callable text vectorizer
    and provides a class definition to allow for compatibility with RedisVL.
    The vectorizer may support both synchronous and asynchronous operations,
    which allows for batch processing of texts, but at a minimum only
    synchronous embedding is required to satisfy the 'embed()' method.

    .. code-block:: python

        # Synchronous embedding of a single text
        vectorizer = CustomTextVectorizer(
            embed = my_vectorizer.generate_embedding
        )
        embedding = vectorizer.embed("Hello, world!")

        # Asynchronous batch embedding of multiple texts
        embeddings = await vectorizer.aembed_many(
            ["Hello, world!", "How are you?"],
            batch_size=2
        )

    """

    # User-provided callables
    _embed: Callable = PrivateAttr()
    _embed_many: Optional[Callable] = PrivateAttr()
    _aembed: Optional[Callable] = PrivateAttr()
    _aembed_many: Optional[Callable] = PrivateAttr()

    # Validation flags for async methods. __init__ only test-calls the sync
    # callables, so async output is validated lazily on first use and the
    # matching flag is set once validation succeeds.
    _aembed_validated: bool = PrivateAttr(default=False)
    _aembed_many_validated: bool = PrivateAttr(default=False)

    def __init__(
        self,
        embed: Callable,
        embed_many: Optional[Callable] = None,
        aembed: Optional[Callable] = None,
        aembed_many: Optional[Callable] = None,
        dtype: str = "float32",
    ):
        """Initialize the Custom vectorizer.

        Args:
            embed (Callable): a Callable function that accepts a string object
                and returns a list of floats.
            embed_many (Optional[Callable]): a Callable function that accepts a
                list of string objects and returns a list containing lists of
                floats. Defaults to None.
            aembed (Optional[Callable]): an asynchronous Callable function that
                accepts a string object and returns a list of floats. Defaults
                to None.
            aembed_many (Optional[Callable]): an asynchronous Callable function
                that accepts a list of string objects and returns a list
                containing lists of floats. Defaults to None.
            dtype (str): the default datatype to use when embedding text as byte
                arrays. Used when setting `as_buffer=True` in calls to embed(),
                embed_many(), aembed() and aembed_many(). Defaults to 'float32'.

        Raises:
            ValueError: if embedding validation fails.
        """
        super().__init__(model=self.type, dtype=dtype)

        # Store user-provided callables
        self._embed = embed
        self._embed_many = embed_many
        self._aembed = aembed
        self._aembed_many = aembed_many

        # Discover dims from a test call to the sync embed callable
        self.dims = self._validate_sync_callables()

    @property
    def type(self) -> str:
        """Identifier for this vectorizer type (always "custom")."""
        return "custom"

    def _validate_sync_callables(self) -> int:
        """
        Validate the sync embed function with a test call and discover the dimension.
        Optionally validate embed_many if provided. Returns the discovered dimension.

        Raises:
            ValueError: If embed or embed_many produce malformed results or fail entirely.
        """
        # Check embed
        try:
            test_single = self._embed("dimension test")
            _check_vector(test_single, "embed")
            dims = len(test_single)
        except Exception as e:
            raise ValueError(f"Invalid embedding method: {e}") from e

        # Check embed_many
        if self._embed_many:
            try:
                test_batch = self._embed_many(["dimension test (many)"])
                _check_vector(test_batch, "embed_many")
            except Exception as e:
                raise ValueError(f"Invalid embedding method: {e}") from e

        return dims

    def _validate_async_result(self, result: Any, method_name: str) -> None:
        """Validate the raw output of an async user callable the first time it is seen.

        The check runs on the callable's own output, before any buffer
        conversion, so `as_buffer=True` calls are validated correctly. On
        success the instance flag ``_<method_name>_validated`` is set and later
        calls skip the check.

        Raises:
            ValueError: If the output does not have the expected structure.
        """
        validated_attr = f"_{method_name}_validated"
        if getattr(self, validated_attr):
            return
        try:
            _check_vector(result, method_name)
        except Exception as e:
            raise ValueError(f"Invalid embedding method: {e}") from e
        setattr(self, validated_attr, True)

    @deprecated_argument("dtype")
    def embed(
        self,
        text: str,
        preprocess: Optional[Callable] = None,
        as_buffer: bool = False,
        **kwargs,
    ) -> Union[List[float], bytes]:
        """
        Generate an embedding for a single piece of text using your sync embed function.

        Args:
            text (str): The text to embed.
            preprocess (Optional[Callable]): An optional callable to preprocess the text.
            as_buffer (bool): If True, return the embedding as a byte buffer.

        Returns:
            Union[List[float], bytes]: The embedding of the input text.

        Raises:
            TypeError: If the input is not a string.
            ValueError: If the user-provided embed function raises.
        """
        if not isinstance(text, str):
            raise TypeError("Must pass in a str value to embed.")

        if preprocess:
            text = preprocess(text)

        # A dtype keyword (deprecated) overrides the instance default; it is
        # removed from kwargs so it is not forwarded to the user callable.
        dtype = kwargs.pop("dtype", self.dtype)

        try:
            result = self._embed(text, **kwargs)
        except Exception as e:
            raise ValueError(f"Embedding text failed: {e}") from e

        return self._process_embedding(result, as_buffer, dtype)

    @deprecated_argument("dtype")
    def embed_many(
        self,
        texts: List[str],
        preprocess: Optional[Callable] = None,
        batch_size: int = 10,
        as_buffer: bool = False,
        **kwargs,
    ) -> Union[List[List[float]], List[bytes]]:
        """
        Generate embeddings for multiple pieces of text in batches using your sync embed_many function.

        Args:
            texts (List[str]): A list of texts to embed.
            preprocess (Optional[Callable]): Optional preprocessing for each text.
            batch_size (int): Number of texts per batch.
            as_buffer (bool): If True, convert each embedding to a byte buffer.

        Returns:
            Union[List[List[float]], List[bytes]]: A list of embeddings, where each embedding is a list of floats or bytes.

        Raises:
            TypeError: If the input is not a list of strings.
            NotImplementedError: If no embed_many function was provided.
            ValueError: If batching, the user callable, or conversion fails.
        """
        if not isinstance(texts, list):
            raise TypeError("Must pass in a list of str values to embed.")
        # Only the first element's type is checked.
        if texts and not isinstance(texts[0], str):
            raise TypeError("Must pass in a list of str values to embed.")

        if not self._embed_many:
            raise NotImplementedError("No embed_many function was provided.")

        dtype = kwargs.pop("dtype", self.dtype)

        embeddings: Union[List[List[float]], List[bytes]] = []
        try:
            for batch in self.batchify(texts, batch_size, preprocess):
                results = self._embed_many(batch, **kwargs)
                processed = [
                    self._process_embedding(r, as_buffer, dtype) for r in results
                ]
                embeddings.extend(processed)
        except Exception as e:
            raise ValueError(f"Embedding text failed: {e}") from e

        return embeddings

    @deprecated_argument("dtype")
    async def aembed(
        self,
        text: str,
        preprocess: Optional[Callable] = None,
        as_buffer: bool = False,
        **kwargs,
    ) -> Union[List[float], bytes]:
        """
        Asynchronously generate an embedding for a single piece of text.

        Args:
            text (str): The text to embed.
            preprocess (Optional[Callable]): An optional callable to preprocess the text.
            as_buffer (bool): If True, return the embedding as a byte buffer.

        Returns:
            Union[List[float], bytes]: The embedding of the input text.

        Raises:
            TypeError: If the input is not a string.
            NotImplementedError: If no aembed function was provided.
            ValueError: If the user callable raises, or (on first use) returns
                malformed output.
        """
        if not isinstance(text, str):
            raise TypeError("Must pass in a str value to embed.")

        if not self._aembed:
            raise NotImplementedError("No aembed function was provided.")

        if preprocess:
            text = preprocess(text)

        dtype = kwargs.pop("dtype", self.dtype)

        try:
            result = await self._aembed(text, **kwargs)
        except Exception as e:
            raise ValueError(f"Embedding text failed: {e}") from e

        # Validate the raw list of floats, not the processed (possibly bytes) output.
        self._validate_async_result(result, "aembed")
        return self._process_embedding(result, as_buffer, dtype)

    @deprecated_argument("dtype")
    async def aembed_many(
        self,
        texts: List[str],
        preprocess: Optional[Callable] = None,
        batch_size: int = 10,
        as_buffer: bool = False,
        **kwargs,
    ) -> Union[List[List[float]], List[bytes]]:
        """
        Asynchronously generate embeddings for multiple pieces of text in batches.

        Args:
            texts (List[str]): The texts to embed.
            preprocess (Optional[Callable]): Optional preprocessing for each text.
            batch_size (int): Number of texts per batch.
            as_buffer (bool): If True, convert each embedding to a byte buffer.

        Returns:
            Union[List[List[float]], List[bytes]]: A list of embeddings, where each embedding is a list of floats or bytes.

        Raises:
            TypeError: If the input is not a list of strings.
            NotImplementedError: If no aembed_many function was provided.
            ValueError: If batching, the user callable, or conversion fails, or
                (on first use with non-empty input) the output is malformed.
        """
        if not isinstance(texts, list):
            raise TypeError("Must pass in a list of str values to embed.")
        # Only the first element's type is checked.
        if texts and not isinstance(texts[0], str):
            raise TypeError("Must pass in a list of str values to embed.")

        if not self._aembed_many:
            raise NotImplementedError("No aembed_many function was provided.")

        dtype = kwargs.pop("dtype", self.dtype)

        # Collect the raw callable output first so it can be validated before
        # any buffer conversion happens.
        raw: List[Any] = []
        try:
            for batch in self.batchify(texts, batch_size, preprocess):
                raw.extend(await self._aembed_many(batch, **kwargs))
        except Exception as e:
            raise ValueError(f"Embedding text failed: {e}") from e

        # An empty `texts` list gives the callable nothing to answer, so an
        # empty result says nothing about its correctness; skip validation then.
        if texts:
            self._validate_async_result(raw, "aembed_many")

        try:
            return [self._process_embedding(r, as_buffer, dtype) for r in raw]
        except Exception as e:
            raise ValueError(f"Embedding text failed: {e}") from e