Source code for redisvl.utils.vectorize.text.cohere

import os
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

from pydantic import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tenacity.retry import retry_if_not_exception_type

from redisvl.utils.utils import deprecated_argument
from redisvl.utils.vectorize.base import BaseVectorizer

# ignore that cohere isn't imported
# mypy: disable-error-code="name-defined"


class CohereTextVectorizer(BaseVectorizer):
    """The CohereTextVectorizer class utilizes Cohere's API to generate
    embeddings for text data.

    This vectorizer is designed to interact with Cohere's /embed API,
    requiring an API key for authentication. The key can be provided directly
    in the `api_config` dictionary or through the `COHERE_API_KEY` environment
    variable. User must obtain an API key from Cohere's website
    (https://dashboard.cohere.com/). Additionally, the `cohere` python client
    must be installed with `pip install cohere`.

    The vectorizer supports only synchronous operations, allows for batch
    processing of texts and flexibility in handling preprocessing tasks.

    .. code-block:: python

        from redisvl.utils.vectorize import CohereTextVectorizer

        vectorizer = CohereTextVectorizer(
            model="embed-english-v3.0",
            api_config={"api_key": "your-cohere-api-key"}  # OR set COHERE_API_KEY in your env
        )
        query_embedding = vectorizer.embed(
            text="your input query text here",
            input_type="search_query"
        )
        doc_embeddings = vectorizer.embed_many(
            texts=["your document text", "more document text"],
            input_type="search_document"
        )

    """

    # Cohere SDK client handle; created lazily in _initialize_client().
    _client: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "embed-english-v3.0",
        api_config: Optional[Dict] = None,
        dtype: str = "float32",
        **kwargs,
    ):
        """Initialize the Cohere vectorizer.

        Visit https://cohere.ai/embed to learn about embeddings.

        Args:
            model (str): Model to use for embedding. Defaults to
                'embed-english-v3.0'.
            api_config (Optional[Dict], optional): Dictionary containing the
                API key. Defaults to None.
            dtype (str): the default datatype to use when embedding text as
                byte arrays. Used when setting `as_buffer=True` in calls to
                embed() and embed_many(). 'float32' will use Cohere's float
                embeddings, 'int8' and 'uint8' will map to Cohere's
                corresponding embedding types. Defaults to 'float32'.

        Raises:
            ImportError: If the cohere library is not installed.
            ValueError: If the API key is not provided.
            ValueError: If an invalid dtype is provided.
        """
        super().__init__(model=model, dtype=dtype)
        # Init client
        self._initialize_client(api_config, **kwargs)
        # Set model dimensions after init (requires one live API call)
        self.dims = self._set_model_dims()

    def _initialize_client(self, api_config: Optional[Dict], **kwargs):
        """
        Setup the Cohere clients using the provided API key or an
        environment variable.
        """
        if api_config is None:
            api_config = {}

        # Dynamic import of the cohere module
        try:
            from cohere import Client
        except ImportError:
            raise ImportError(
                "Cohere vectorizer requires the cohere library. "
                "Please install with `pip install cohere`"
            )

        # Bug fix: fall back to the environment variable whenever the config
        # dict does not carry an explicit "api_key" entry. Previously a
        # non-empty api_config without "api_key" skipped the env var lookup
        # entirely and raised even when COHERE_API_KEY was set.
        api_key = api_config.get("api_key") or os.getenv("COHERE_API_KEY")

        if not api_key:
            raise ValueError(
                "Cohere API key is required. "
                "Provide it in api_config or set the COHERE_API_KEY environment variable."
            )

        self._client = Client(api_key=api_key, client_name="redisvl", **kwargs)

    def _set_model_dims(self) -> int:
        """Probe the embeddings endpoint once to discover the model's
        output dimensionality."""
        try:
            embedding = self.embed("dimension check", input_type="search_document")
        except (KeyError, IndexError) as ke:
            raise ValueError(f"Unexpected response from the Cohere API: {str(ke)}")
        except Exception as e:  # pylint: disable=broad-except
            # fall back (TODO get more specific)
            raise ValueError(f"Error setting embedding model dimensions: {str(e)}")
        return len(embedding)

    def _get_cohere_embedding_type(self, dtype: str) -> List[str]:
        """Map dtype to appropriate Cohere embedding_types value."""
        if dtype == "int8":
            return ["int8"]
        elif dtype == "uint8":
            return ["uint8"]
        else:
            return ["float"]
[docs] @deprecated_argument("dtype") def embed( self, text: str, preprocess: Optional[Callable] = None, as_buffer: bool = False, **kwargs, ) -> Union[List[float], List[int], bytes]: """Embed a chunk of text using the Cohere Embeddings API. Must provide the embedding `input_type` as a `kwarg` to this method that specifies the type of input you're giving to the model. Supported input types: - ``search_document``: Used for embeddings stored in a vector database for search use-cases. - ``search_query``: Used for embeddings of search queries run against a vector DB to find relevant documents. - ``classification``: Used for embeddings passed through a text classifier - ``clustering``: Used for the embeddings run through a clustering algorithm. When hydrating your Redis DB, the documents you want to search over should be embedded with input_type= "search_document" and when you are querying the database, you should set the input_type = "search query". If you want to use the embeddings for a classification or clustering task downstream, you should set input_type= "classification" or "clustering". Args: text (str): Chunk of text to embed. preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. input_type (str): Specifies the type of input passed to the model. Required for embedding models v3 and higher. Returns: Union[List[float], List[int], bytes]: - If as_buffer=True: Returns a bytes object - If as_buffer=False: - For dtype="float32": Returns a list of floats - For dtype="int8" or "uint8": Returns a list of integers Raises: TypeError: In an invalid input_type is provided. 
""" input_type = kwargs.pop("input_type", None) if not isinstance(text, str): raise TypeError("Must pass in a str value to embed.") if not isinstance(input_type, str): raise TypeError( "Must pass in a str value for cohere embedding input_type. \ See https://docs.cohere.com/reference/embed." ) if preprocess: text = preprocess(text) dtype = kwargs.pop("dtype", self.dtype) # Check if embedding_types was provided and warn user if "embedding_types" in kwargs: warnings.warn( "The 'embedding_types' parameter is not supported in CohereTextVectorizer. " "Please use the 'dtype' parameter instead. Your 'embedding_types' value will be ignored.", UserWarning, stacklevel=2, ) kwargs.pop("embedding_types") # Map dtype to appropriate embedding_type embedding_types = self._get_cohere_embedding_type(dtype) response = self._client.embed( texts=[text], model=self.model, input_type=input_type, embedding_types=embedding_types, **kwargs, ) # Extract the appropriate embedding based on embedding_types embed_type = embedding_types[0] if hasattr(response.embeddings, embed_type): embedding = getattr(response.embeddings, embed_type)[0] else: embedding = response.embeddings[0] # Fallback for older API versions return self._process_embedding(embedding, as_buffer, dtype)
[docs] @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) @deprecated_argument("dtype") def embed_many( self, texts: List[str], preprocess: Optional[Callable] = None, batch_size: int = 10, as_buffer: bool = False, **kwargs, ) -> Union[List[List[float]], List[List[int]], List[bytes]]: """Embed many chunks of text using the Cohere Embeddings API. Must provide the embedding `input_type` as a `kwarg` to this method that specifies the type of input you're giving to the model. Supported input types: - ``search_document``: Used for embeddings stored in a vector database for search use-cases. - ``search_query``: Used for embeddings of search queries run against a vector DB to find relevant documents. - ``classification``: Used for embeddings passed through a text classifier - ``clustering``: Used for the embeddings run through a clustering algorithm. When hydrating your Redis DB, the documents you want to search over should be embedded with input_type= "search_document" and when you are querying the database, you should set the input_type = "search query". If you want to use the embeddings for a classification or clustering task downstream, you should set input_type= "classification" or "clustering". Args: texts (List[str]): List of text chunks to embed. preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. batch_size (int, optional): Batch size of texts to use when creating embeddings. Defaults to 10. as_buffer (bool, optional): Whether to convert the raw embedding to a byte string. Defaults to False. input_type (str): Specifies the type of input passed to the model. Required for embedding models v3 and higher. 
Returns: Union[List[List[float]], List[List[int]], List[bytes]]: - If as_buffer=True: Returns a list of bytes objects - If as_buffer=False: - For dtype="float32": Returns a list of lists of floats - For dtype="int8" or "uint8": Returns a list of lists of integers Raises: TypeError: In an invalid input_type is provided. """ input_type = kwargs.pop("input_type", None) if not isinstance(texts, list): raise TypeError("Must pass in a list of str values to embed.") if len(texts) > 0 and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") if not isinstance(input_type, str): raise TypeError( "Must pass in a str value for cohere embedding input_type.\ See https://docs.cohere.com/reference/embed." ) dtype = kwargs.pop("dtype", self.dtype) # Check if embedding_types was provided and warn user if "embedding_types" in kwargs: warnings.warn( "The 'embedding_types' parameter is not supported in CohereTextVectorizer. " "Please use the 'dtype' parameter instead. Your 'embedding_types' value will be ignored.", UserWarning, stacklevel=2, ) kwargs.pop("embedding_types") # Map dtype to appropriate embedding_type embedding_types = self._get_cohere_embedding_type(dtype) embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): response = self._client.embed( texts=batch, model=self.model, input_type=input_type, embedding_types=embedding_types, **kwargs, ) # Extract the appropriate embeddings based on embedding_types embed_type = embedding_types[0] if hasattr(response.embeddings, embed_type): batch_embeddings = getattr(response.embeddings, embed_type) else: batch_embeddings = ( response.embeddings ) # Fallback for older API versions embeddings += [ self._process_embedding(embedding, as_buffer, dtype) for embedding in batch_embeddings ] return embeddings
@property def type(self) -> str: return "cohere"