Source code for redisvl.utils.vectorize.text.vertexai

import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from pydantic import ConfigDict
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tenacity.retry import retry_if_not_exception_type

if TYPE_CHECKING:
    from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache

from redisvl.utils.utils import deprecated_argument
from redisvl.utils.vectorize.base import BaseVectorizer


class VertexAITextVectorizer(BaseVectorizer):
    """Create text embeddings with Google's VertexAI Palm 2 embedding model API.

    This vectorizer is tailored for use in environments where integration
    with Google Cloud Platform (GCP) services is a key requirement.

    Utilizing this vectorizer requires an active GCP project and location
    (region), along with appropriate application credentials. The project and
    location can be provided through the `api_config` dictionary or through
    the GCP_PROJECT_ID / GCP_LOCATION environment variables; each key is
    resolved independently, so a partial `api_config` is completed from the
    environment. Credentials can be passed under the `credentials` key of
    `api_config` or supplied by setting the GOOGLE_APPLICATION_CREDENTIALS
    env var. Additionally, the vertexai python client must be installed with
    `pip install google-cloud-aiplatform>=1.26`.

    You can optionally enable caching to improve performance when generating
    embeddings for repeated text inputs.

    .. code-block:: python

        # Basic usage
        vectorizer = VertexAITextVectorizer(
            model="textembedding-gecko",
            api_config={
                "project_id": "your_gcp_project_id",  # OR set GCP_PROJECT_ID
                "location": "your_gcp_location",  # OR set GCP_LOCATION
            })
        embedding = vectorizer.embed("Hello, world!")

        # With caching enabled
        from redisvl.extensions.cache.embeddings import EmbeddingsCache
        cache = EmbeddingsCache(name="vertexai_embeddings_cache")

        vectorizer = VertexAITextVectorizer(
            model="textembedding-gecko",
            api_config={
                "project_id": "your_gcp_project_id",
                "location": "your_gcp_location",
            },
            cache=cache
        )

        # First call will compute and cache the embedding
        embedding1 = vectorizer.embed("Hello, world!")

        # Second call will retrieve from cache
        embedding2 = vectorizer.embed("Hello, world!")

        # Batch embedding of multiple texts
        embeddings = vectorizer.embed_many(
            ["Hello, world!", "Goodbye, world!"],
            batch_size=2
        )

    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(
        self,
        model: str = "textembedding-gecko",
        api_config: Optional[Dict] = None,
        dtype: str = "float32",
        cache: Optional["EmbeddingsCache"] = None,
        **kwargs,
    ):
        """Initialize the VertexAI vectorizer.

        Args:
            model (str): Model to use for embedding.
                Defaults to 'textembedding-gecko'.
            api_config (Optional[Dict], optional): Dictionary containing the
                API config details ('project_id', 'location' and optionally
                'credentials'). Defaults to None.
            dtype (str): the default datatype to use when embedding text as
                byte arrays. Used when setting `as_buffer=True` in calls to
                embed() and embed_many(). Defaults to 'float32'.
            cache (Optional[EmbeddingsCache]): Optional EmbeddingsCache
                instance to cache embeddings for better performance with
                repeated texts. Defaults to None.
            **kwargs: Forwarded to the client setup helpers.

        Raises:
            ImportError: If the google-cloud-aiplatform library is not installed.
            ValueError: If the GCP project id or location is not provided.
            ValueError: If the embedding dimensions cannot be determined.
        """
        super().__init__(model=model, dtype=dtype, cache=cache)
        # Initialize client and set up the model
        self._setup(api_config, **kwargs)

    def _setup(self, api_config: Optional[Dict], **kwargs):
        """Set up the VertexAI client and determine the embedding dimensions."""
        # Initialize client
        self._initialize_client(api_config, **kwargs)
        # Set model dimensions after initialization (requires a live client)
        self.dims = self._set_model_dims()

    def _initialize_client(self, api_config: Optional[Dict], **kwargs):
        """Set up the VertexAI client from config options or environment variables.

        Args:
            api_config: Dictionary with GCP configuration options. Recognised
                keys are 'project_id', 'location' and 'credentials'.
            **kwargs: Additional arguments for initialization (accepted for
                interface compatibility; not used here).

        Raises:
            ImportError: If the google-cloud-aiplatform library is not installed
            ValueError: If required parameters are not provided
        """
        config = api_config or {}

        # Resolve each setting independently: a value in api_config wins,
        # otherwise fall back to the environment. This lets a partial config
        # (e.g. only 'credentials') still pick up GCP_PROJECT_ID/GCP_LOCATION.
        project_id = config.get("project_id") or os.getenv("GCP_PROJECT_ID")
        location = config.get("location") or os.getenv("GCP_LOCATION")

        if not project_id:
            raise ValueError(
                "Missing project_id. "
                "Provide the id in the api_config with key 'project_id' "
                "or set the GCP_PROJECT_ID environment variable."
            )

        if not location:
            raise ValueError(
                "Missing location. "
                "Provide the location (region) in the api_config with key 'location' "
                "or set the GCP_LOCATION environment variable."
            )

        # Check for credentials
        credentials = config.get("credentials")

        # Keep the try body limited to the optional imports so that an
        # ImportError raised elsewhere is not misreported as a missing package.
        try:
            import vertexai
            from vertexai.language_models import TextEmbeddingModel
        except ImportError as e:
            raise ImportError(
                "VertexAI vectorizer requires the google-cloud-aiplatform library. "
                "Please install with `pip install google-cloud-aiplatform>=1.26`"
            ) from e

        vertexai.init(project=project_id, location=location, credentials=credentials)

        # Store client as a regular attribute instead of PrivateAttr
        self._client = TextEmbeddingModel.from_pretrained(self.model)

    def _set_model_dims(self) -> int:
        """Determine the dimensionality of the embedding model by making a test call.

        Returns:
            int: Dimensionality of the embedding model

        Raises:
            ValueError: If embedding dimensions cannot be determined
        """
        try:
            # Call the protected _embed method to avoid caching this test embedding
            embedding = self._embed("dimension check")
            return len(embedding)
        except (KeyError, IndexError) as ke:
            raise ValueError(
                f"Unexpected response from the VertexAI API: {str(ke)}"
            ) from ke
        except Exception as e:  # pylint: disable=broad-except
            # fall back (TODO get more specific)
            raise ValueError(
                f"Error setting embedding model dimensions: {str(e)}"
            ) from e

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(TypeError),
    )
    def _embed(self, text: str, **kwargs) -> List[float]:
        """Generate a vector embedding for a single text using the VertexAI API.

        Failures other than TypeError are retried by the decorator (up to six
        attempts with randomized exponential backoff).

        Args:
            text: Text to embed
            **kwargs: Additional parameters to pass to the VertexAI API

        Returns:
            List[float]: Vector embedding as a list of floats

        Raises:
            TypeError: If text is not a string
            ValueError: If embedding fails
        """
        if not isinstance(text, str):
            raise TypeError("Must pass in a str value to embed.")

        try:
            result = self._client.get_embeddings([text], **kwargs)
            return result[0].values
        except Exception as e:
            raise ValueError(f"Embedding text failed: {e}") from e

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(TypeError),
    )
    def _embed_many(
        self, texts: List[str], batch_size: int = 10, **kwargs
    ) -> List[List[float]]:
        """Generate vector embeddings for a batch of texts using the VertexAI API.

        Failures other than TypeError are retried by the decorator (up to six
        attempts with randomized exponential backoff); a retry re-embeds the
        whole input, not just the failing batch.

        Args:
            texts: List of texts to embed
            batch_size: Number of texts to process in each API call
            **kwargs: Additional parameters to pass to the VertexAI API

        Returns:
            List[List[float]]: List of vector embeddings as lists of floats

        Raises:
            TypeError: If texts is not a list of strings
            ValueError: If embedding fails
        """
        if not isinstance(texts, list):
            raise TypeError("Must pass in a list of str values to embed.")
        # Only the first element is inspected; an empty list is accepted.
        if texts and not isinstance(texts[0], str):
            raise TypeError("Must pass in a list of str values to embed.")

        try:
            embeddings: List = []
            for batch in self.batchify(texts, batch_size):
                response = self._client.get_embeddings(batch, **kwargs)
                embeddings.extend(r.values for r in response)
            return embeddings
        except Exception as e:
            raise ValueError(f"Embedding texts failed: {e}") from e

    @property
    def type(self) -> str:
        """Identifier string for this vectorizer implementation."""
        return "vertexai"