import os
from functools import cached_property
from typing import TYPE_CHECKING, Any
from pydantic import ConfigDict
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tenacity.retry import retry_if_not_exception_type
if TYPE_CHECKING:
from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache
from redisvl.utils.vectorize.base import BaseVectorizer
class VertexAIVectorizer(BaseVectorizer):
    """The VertexAIVectorizer uses Google's VertexAI embedding model
    API to create embeddings.

    This vectorizer is tailored for use in environments where integration
    with Google Cloud Platform (GCP) services is a key requirement.

    Utilizing this vectorizer requires an active GCP project and location
    (region), along with appropriate application credentials. The project and
    location can be provided through the `api_config` dictionary
    (``project_id`` / ``location`` keys) or through the GCP_PROJECT_ID /
    GCP_LOCATION environment variables. Credentials can be passed as
    ``api_config["credentials"]`` or supplied by setting the
    GOOGLE_APPLICATION_CREDENTIALS environment variable. Additionally, the
    vertexai python client must be installed with
    `pip install google-cloud-aiplatform>=1.26`.

    You can optionally enable caching to improve performance when generating
    embeddings for repeated inputs.

    .. code-block:: python

        # Basic usage
        vectorizer = VertexAIVectorizer(
            model="textembedding-gecko",
            api_config={
                "project_id": "your_gcp_project_id", # OR set GCP_PROJECT_ID
                "location": "your_gcp_location", # OR set GCP_LOCATION
            })
        embedding = vectorizer.embed("Hello, world!")

        # With caching enabled
        from redisvl.extensions.cache.embeddings import EmbeddingsCache
        cache = EmbeddingsCache(name="vertexai_embeddings_cache")

        vectorizer = VertexAIVectorizer(
            model="textembedding-gecko",
            api_config={
                "project_id": "your_gcp_project_id",
                "location": "your_gcp_location",
            },
            cache=cache
        )

        # First call will compute and cache the embedding
        embedding1 = vectorizer.embed("Hello, world!")

        # Second call will retrieve from cache
        embedding2 = vectorizer.embed("Hello, world!")

        # Batch embedding of multiple texts (text models only)
        embeddings = vectorizer.embed_many(
            ["Hello, world!", "Goodbye, world!"],
            batch_size=2
        )

        # Multimodal usage
        from vertexai.vision_models import Image, Video

        vectorizer = VertexAIVectorizer(
            model="multimodalembedding@001",
            api_config={
                "project_id": "your_gcp_project_id", # OR set GCP_PROJECT_ID
                "location": "your_gcp_location", # OR set GCP_LOCATION
            }
        )
        text_embedding = vectorizer.embed("Hello, world!")
        image_embedding = vectorizer.embed(Image.load_from_file("path/to/your/image.jpg"))
        video_embedding = vectorizer.embed(Video.load_from_file("path/to/your/video.mp4"))
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(
        self,
        model: str = "textembedding-gecko",
        api_config: dict[str, Any] | None = None,
        dtype: str = "float32",
        cache: "EmbeddingsCache | None" = None,
        **kwargs,
    ):
        """Initialize the VertexAI vectorizer.

        Args:
            model (str): Model to use for embedding. Defaults to
                'textembedding-gecko'. Model names containing "multimodal"
                are handled by the multimodal embedding client.
            api_config (Optional[Dict], optional): Dictionary containing the
                API config details. Recognized keys are "project_id",
                "location" and "credentials". Defaults to None.
            dtype (str): the default datatype to use when embedding text as byte arrays.
                Used when setting `as_buffer=True` in calls to embed() and embed_many().
                Defaults to 'float32'.
            cache (Optional[EmbeddingsCache]): Optional EmbeddingsCache instance to cache embeddings for
                better performance with repeated texts. Defaults to None.
            **kwargs: Passed through to client setup (currently unused there).

        Raises:
            ImportError: If the google-cloud-aiplatform library is not installed.
            ValueError: If the GCP project_id or location is not provided.
            ValueError: If an invalid dtype is provided.
            ValueError: If the embedding dimensions cannot be determined.
        """
        super().__init__(model=model, dtype=dtype, cache=cache)
        # Initialize client and set up the model (this makes one API call).
        self._setup(api_config, **kwargs)

    @property
    def is_multimodal(self) -> bool:
        """Whether a multimodal model has been configured."""
        return "multimodal" in self.model

    @cached_property
    def _client(self):
        """Lazily build (once) the VertexAI model client matching the model type."""
        if self.is_multimodal:
            from vertexai.vision_models import MultiModalEmbeddingModel

            return MultiModalEmbeddingModel.from_pretrained(self.model)

        from vertexai.language_models import TextEmbeddingModel

        return TextEmbeddingModel.from_pretrained(self.model)

    def embed_image(self, image_path: str, **kwargs) -> list[float] | bytes:
        """Embed an image (from its path on disk) using a VertexAI multimodal model.

        Args:
            image_path: Path to the image file.
            **kwargs: Forwarded to ``embed``.

        Raises:
            ValueError: If the configured model is not multimodal.
        """
        if not self.is_multimodal:
            raise ValueError("Cannot embed image with a non-multimodal model.")

        from vertexai.vision_models import Image

        return self.embed(Image.load_from_file(image_path), **kwargs)

    def embed_video(self, video_path: str, **kwargs) -> list[float] | bytes:
        """Embed a video (from its path on disk) using a VertexAI multimodal model.

        Args:
            video_path: Path to the video file.
            **kwargs: Forwarded to ``embed``.

        Raises:
            ValueError: If the configured model is not multimodal.
        """
        if not self.is_multimodal:
            raise ValueError("Cannot embed video with a non-multimodal model.")

        from vertexai.vision_models import Video

        return self.embed(Video.load_from_file(video_path), **kwargs)

    def _setup(self, api_config: dict[str, Any] | None, **kwargs):
        """Set up the VertexAI client and determine the embedding dimensions."""
        # Initialize client
        self._initialize_client(api_config, **kwargs)
        # Set model dimensions after initialization
        self.dims = self._set_model_dims()

    def _initialize_client(self, api_config: dict[str, Any] | None, **kwargs):
        """
        Initialize the VertexAI SDK using the provided config options or
        environment variables.

        ``project_id`` and ``location`` are each read from ``api_config``
        first and fall back to the GCP_PROJECT_ID / GCP_LOCATION environment
        variables when the key is absent or empty, so both sources can be
        mixed (e.g. project from config, location from the environment).

        Args:
            api_config: Dictionary with GCP configuration options
                ("project_id", "location", "credentials").
            **kwargs: Additional arguments for initialization (currently unused).

        Raises:
            ImportError: If the google-cloud-aiplatform library is not installed
            ValueError: If required parameters are not provided
        """
        config = api_config or {}

        # Per-key fallback: a partial api_config must not disable the env vars.
        project_id = config.get("project_id") or os.getenv("GCP_PROJECT_ID")
        location = config.get("location") or os.getenv("GCP_LOCATION")

        if not project_id:
            raise ValueError(
                "Missing project_id. "
                "Provide the id in the api_config with key 'project_id' "
                "or set the GCP_PROJECT_ID environment variable."
            )

        if not location:
            raise ValueError(
                "Missing location. "
                "Provide the location (region) in the api_config with key 'location' "
                "or set the GCP_LOCATION environment variable."
            )

        # Optional explicit credentials; None lets the SDK use its defaults.
        credentials = config.get("credentials")

        try:
            import vertexai
        except ImportError as e:
            raise ImportError(
                "VertexAI vectorizer requires the google-cloud-aiplatform library. "
                "Please install with `pip install google-cloud-aiplatform>=1.26`"
            ) from e

        vertexai.init(project=project_id, location=location, credentials=credentials)

    def _set_model_dims(self) -> int:
        """
        Determine the dimensionality of the embedding model by making a test call.

        Returns:
            int: Dimensionality of the embedding model

        Raises:
            ValueError: If embedding dimensions cannot be determined
        """
        try:
            # Call the protected _embed method to avoid caching this test embedding
            embedding = self._embed("dimension check")
            return len(embedding)
        except Exception as e:  # pylint: disable=broad-except
            # _embed already normalizes every failure to ValueError, so a single
            # handler covers all cases; chain to keep the original cause.
            raise ValueError(
                f"Error setting embedding model dimensions: {str(e)}"
            ) from e

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(
            (TypeError, ValueError, NotImplementedError)
        ),
    )
    def _embed(self, content: Any, **kwargs) -> list[float]:
        """
        Generate a vector embedding for a single input using the VertexAI API.

        Text models send ``[content]`` to ``get_embeddings``. Multimodal models
        dispatch on the input type: ``str`` is sent as ``contextual_text``,
        ``Image`` as ``image`` and ``Video`` as ``video`` (first segment used).

        Args:
            content: Input to embed
            **kwargs: Additional parameters to pass to the VertexAI API

        Returns:
            List[float]: Vector embedding as a list of floats

        Raises:
            ValueError: If embedding fails, including for an unsupported
                multimodal input type or an empty response.
        """
        # NOTE(review): every failure here is re-raised as ValueError, which the
        # retry policy excludes, so the decorator never retries (not even
        # transient API errors). Kept as-is to preserve fail-fast behavior and
        # the ValueError contract; revisit if retries are actually wanted.
        try:
            if not self.is_multimodal:
                return self._client.get_embeddings([content], **kwargs)[0].values

            from vertexai.vision_models import Image, Video

            if isinstance(content, str):
                result = self._client.get_embeddings(
                    contextual_text=content,
                    **kwargs,
                )
                if result.text_embedding is None:
                    raise ValueError("No text embedding returned from VertexAI.")
                return result.text_embedding

            if isinstance(content, Image):
                result = self._client.get_embeddings(
                    image=content,
                    **kwargs,
                )
                if result.image_embedding is None:
                    raise ValueError("No image embedding returned from VertexAI.")
                return result.image_embedding

            if isinstance(content, Video):
                result = self._client.get_embeddings(
                    video=content,
                    **kwargs,
                )
                # Reject an empty list as well as None before indexing [0].
                if not result.video_embeddings:
                    raise ValueError("No video embedding returned from VertexAI.")
                return result.video_embeddings[0].embedding

            # Surfaces to callers as ValueError via the handler below.
            raise TypeError(
                "Invalid input type for multimodal embedding. "
                "Must be str, Image, or Video."
            )
        except Exception as e:
            raise ValueError(f"Embedding input failed: {e}") from e

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        # NotImplementedError must be excluded: otherwise the multimodal guard
        # below is retried with backoff and surfaces as tenacity.RetryError.
        retry=retry_if_not_exception_type(
            (TypeError, ValueError, NotImplementedError)
        ),
    )
    def _embed_many(
        self, contents: list[str], batch_size: int = 10, **kwargs
    ) -> list[list[float]]:
        """
        Generate vector embeddings for a batch of texts using the VertexAI API.

        Args:
            contents: List of texts to embed
            batch_size: Number of texts to process in each API call
            **kwargs: Additional parameters to pass to the VertexAI API

        Returns:
            List[List[float]]: List of vector embeddings as lists of floats

        Raises:
            NotImplementedError: If the configured model is multimodal.
            TypeError: If ``contents`` is not a list, or its first element
                is not a str.
            ValueError: If embedding fails
        """
        if self.is_multimodal:
            raise NotImplementedError(
                "Batch embedding is not supported for multimodal models with VertexAI."
            )
        if not isinstance(contents, list):
            raise TypeError("Must pass in a list of str values to embed.")
        # Only the first element is type-checked.
        if contents and not isinstance(contents[0], str):
            raise TypeError("Must pass in a list of str values to embed.")

        # NOTE(review): as in _embed, API failures become ValueError and are
        # therefore never retried by the decorator above.
        try:
            embeddings: list[list[float]] = []
            for batch in self.batchify(contents, batch_size):
                response = self._client.get_embeddings(batch, **kwargs)
                embeddings.extend([r.values for r in response])
            return embeddings
        except Exception as e:
            raise ValueError(f"Embedding texts failed: {e}") from e

    def _serialize_for_cache(self, content: Any) -> bytes | str:
        """Convert content to a cacheable format.

        VertexAI ``Image`` / ``Video`` inputs are keyed by their raw bytes;
        anything else is delegated to the base implementation.
        """
        from vertexai.vision_models import Image, Video

        # NOTE: _image_bytes / _video_bytes are private SDK attributes.
        if isinstance(content, Image):
            return content._image_bytes
        if isinstance(content, Video):
            return content._video_bytes
        return super()._serialize_for_cache(content)

    @property
    def type(self) -> str:
        """Identifier for this vectorizer provider."""
        return "vertexai"