import os
from typing import TYPE_CHECKING, Dict, List, Optional
from pydantic import ConfigDict
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tenacity.retry import retry_if_not_exception_type
if TYPE_CHECKING:
from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache
from redisvl.utils.utils import deprecated_argument
from redisvl.utils.vectorize.base import BaseVectorizer
# ignore that openai isn't imported
# mypy: disable-error-code="name-defined"
[docs]
class OpenAITextVectorizer(BaseVectorizer):
    """The OpenAITextVectorizer class utilizes OpenAI's API to generate
    embeddings for text data.

    This vectorizer is designed to interact with OpenAI's embeddings API,
    requiring an API key for authentication. The key can be provided directly
    in the `api_config` dictionary or through the `OPENAI_API_KEY` environment
    variable. Users must obtain an API key from OpenAI's website
    (https://api.openai.com/). Additionally, the `openai` python client must be
    installed with `pip install openai>=1.13.0`.

    The vectorizer supports both synchronous and asynchronous operations,
    allowing for batch processing of texts and flexibility in handling
    preprocessing tasks.

    You can optionally enable caching to improve performance when generating
    embeddings for repeated text inputs.

    .. code-block:: python

        # Basic usage with OpenAI embeddings
        vectorizer = OpenAITextVectorizer(
            model="text-embedding-ada-002",
            api_config={"api_key": "your_api_key"}  # OR set OPENAI_API_KEY in your env
        )
        embedding = vectorizer.embed("Hello, world!")

        # With caching enabled
        from redisvl.extensions.cache.embeddings import EmbeddingsCache
        cache = EmbeddingsCache(name="openai_embeddings_cache")

        vectorizer = OpenAITextVectorizer(
            model="text-embedding-ada-002",
            api_config={"api_key": "your_api_key"},
            cache=cache
        )

        # First call will compute and cache the embedding
        embedding1 = vectorizer.embed("Hello, world!")

        # Second call will retrieve from cache
        embedding2 = vectorizer.embed("Hello, world!")

        # Asynchronous batch embedding of multiple texts
        embeddings = await vectorizer.aembed_many(
            ["Hello, world!", "How are you?"],
            batch_size=2
        )

    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(
        self,
        model: str = "text-embedding-ada-002",
        api_config: Optional[Dict] = None,
        dtype: str = "float32",
        cache: Optional["EmbeddingsCache"] = None,
        **kwargs,
    ):
        """Initialize the OpenAI vectorizer.

        Args:
            model (str): Model to use for embedding. Defaults to
                'text-embedding-ada-002'.
            api_config (Optional[Dict], optional): Dictionary containing the
                API key and any additional OpenAI API options. The dictionary
                is not modified. Defaults to None.
            dtype (str): the default datatype to use when embedding text as byte arrays.
                Used when setting `as_buffer=True` in calls to embed() and embed_many().
                Defaults to 'float32'.
            cache (Optional[EmbeddingsCache]): Optional EmbeddingsCache instance to cache embeddings for
                better performance with repeated texts. Defaults to None.
            **kwargs: Additional keyword arguments forwarded to the OpenAI
                client constructors.

        Raises:
            ImportError: If the openai library is not installed.
            ValueError: If the OpenAI API key is not provided.
            ValueError: If an invalid dtype is provided.
        """
        super().__init__(model=model, dtype=dtype, cache=cache)
        # Initialize clients and set up the model
        self._setup(api_config, **kwargs)

    def _setup(self, api_config: Optional[Dict], **kwargs):
        """Set up the OpenAI clients and determine the embedding dimensions."""
        # Initialize clients
        self._initialize_clients(api_config, **kwargs)
        # Set model dimensions after client initialization (needs a live client)
        self.dims = self._set_model_dims()

    def _initialize_clients(self, api_config: Optional[Dict], **kwargs):
        """
        Setup the OpenAI clients using the provided API key or an
        environment variable.

        The API key is taken from ``api_config["api_key"]`` when present and
        non-empty; otherwise the ``OPENAI_API_KEY`` environment variable is
        used. Every other entry of ``api_config`` is forwarded to both the
        synchronous and asynchronous client constructors.

        Args:
            api_config: Dictionary with API configuration options
            **kwargs: Additional arguments to pass to OpenAI clients

        Raises:
            ImportError: If the openai library is not installed
            ValueError: If no API key is provided
        """
        # Dynamic import of the openai module
        try:
            from openai import AsyncOpenAI, OpenAI
        except ImportError as exc:
            raise ImportError(
                "OpenAI vectorizer requires the openai library. "
                "Please install with `pip install openai>=1.13.0`"
            ) from exc

        # Work on a shallow copy so the caller's dictionary is never mutated.
        client_options: Dict = dict(api_config) if api_config else {}

        # Fall back to the environment whenever the config does not carry a
        # key -- including the case where api_config only holds other options
        # (e.g. a custom base_url).
        api_key = client_options.pop("api_key", None) or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError(
                "OpenAI API key is required. "
                "Provide it in api_config or set the OPENAI_API_KEY environment variable."
            )

        self._client = OpenAI(api_key=api_key, **client_options, **kwargs)
        self._aclient = AsyncOpenAI(api_key=api_key, **client_options, **kwargs)

    def _set_model_dims(self) -> int:
        """
        Determine the dimensionality of the embedding model by making a test call.

        Returns:
            int: Dimensionality of the embedding model

        Raises:
            ValueError: If embedding dimensions cannot be determined
        """
        try:
            # Call the raw API method directly; this probe goes straight to
            # _embed() rather than through a cache-aware wrapper.
            embedding = self._embed("dimension check")
            return len(embedding)
        except (KeyError, IndexError) as ke:
            raise ValueError(
                f"Unexpected response from the OpenAI API: {str(ke)}"
            ) from ke
        except Exception as e:  # pylint: disable=broad-except
            # fall back (TODO get more specific)
            raise ValueError(
                f"Error setting embedding model dimensions: {str(e)}"
            ) from e

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(TypeError),
    )
    def _embed(self, text: str, **kwargs) -> List[float]:
        """Generate a vector embedding for a single text using the OpenAI API.

        Any failure other than a ``TypeError`` is retried with random
        exponential backoff, for at most 6 attempts. NOTE(review): the
        decorator does not set ``reraise=True``, so once attempts are
        exhausted tenacity's default is to raise its own ``RetryError``
        wrapping the last failure.

        Args:
            text: Text to embed
            **kwargs: Additional parameters to pass to the OpenAI API

        Returns:
            List[float]: Vector embedding as a list of floats

        Raises:
            TypeError: If text is not a string
            ValueError: If embedding fails
        """
        if not isinstance(text, str):
            raise TypeError("Must pass in a str value to embed.")

        try:
            result = self._client.embeddings.create(
                input=[text], model=self.model, **kwargs
            )
            return result.data[0].embedding
        except Exception as e:
            raise ValueError(f"Embedding text failed: {e}") from e

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(TypeError),
    )
    def _embed_many(
        self, texts: List[str], batch_size: int = 10, **kwargs
    ) -> List[List[float]]:
        """Generate vector embeddings for a batch of texts using the OpenAI API.

        Texts are split with ``self.batchify`` and one API call is issued per
        batch. A failure in any batch restarts the whole method on retry
        (the retry decorator wraps the entire call, not a single batch).

        Args:
            texts: List of texts to embed
            batch_size: Number of texts to process in each API call
            **kwargs: Additional parameters to pass to the OpenAI API

        Returns:
            List[List[float]]: List of vector embeddings as lists of floats

        Raises:
            TypeError: If texts is not a list of strings
            ValueError: If embedding fails
        """
        if not isinstance(texts, list):
            raise TypeError("Must pass in a list of str values to embed.")
        # Only the first element is inspected; an empty list is accepted.
        if texts and not isinstance(texts[0], str):
            raise TypeError("Must pass in a list of str values to embed.")

        embeddings: List = []
        for batch in self.batchify(texts, batch_size):
            try:
                response = self._client.embeddings.create(
                    input=batch, model=self.model, **kwargs
                )
                embeddings += [r.embedding for r in response.data]
            except Exception as e:
                raise ValueError(f"Embedding texts failed: {e}") from e
        return embeddings

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(TypeError),
    )
    async def _aembed(self, text: str, **kwargs) -> List[float]:
        """Asynchronously generate a vector embedding for a single text using the OpenAI API.

        Uses the async client; retry policy matches the synchronous variant
        (everything except ``TypeError`` is retried, at most 6 attempts).

        Args:
            text: Text to embed
            **kwargs: Additional parameters to pass to the OpenAI API

        Returns:
            List[float]: Vector embedding as a list of floats

        Raises:
            TypeError: If text is not a string
            ValueError: If embedding fails
        """
        if not isinstance(text, str):
            raise TypeError("Must pass in a str value to embed.")

        try:
            result = await self._aclient.embeddings.create(
                input=[text], model=self.model, **kwargs
            )
            return result.data[0].embedding
        except Exception as e:
            raise ValueError(f"Embedding text failed: {e}") from e

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(TypeError),
    )
    async def _aembed_many(
        self, texts: List[str], batch_size: int = 10, **kwargs
    ) -> List[List[float]]:
        """Asynchronously generate vector embeddings for a batch of texts using the OpenAI API.

        Batches are awaited one after another (sequentially), each through
        the async client.

        Args:
            texts: List of texts to embed
            batch_size: Number of texts to process in each API call
            **kwargs: Additional parameters to pass to the OpenAI API

        Returns:
            List[List[float]]: List of vector embeddings as lists of floats

        Raises:
            TypeError: If texts is not a list of strings
            ValueError: If embedding fails
        """
        if not isinstance(texts, list):
            raise TypeError("Must pass in a list of str values to embed.")
        # Only the first element is inspected; an empty list is accepted.
        if texts and not isinstance(texts[0], str):
            raise TypeError("Must pass in a list of str values to embed.")

        embeddings: List = []
        for batch in self.batchify(texts, batch_size):
            try:
                response = await self._aclient.embeddings.create(
                    input=batch, model=self.model, **kwargs
                )
                embeddings += [r.embedding for r in response.data]
            except Exception as e:
                raise ValueError(f"Embedding texts failed: {e}") from e
        return embeddings

    @property
    def type(self) -> str:
        """Return the identifier string for this vectorizer: ``"openai"``."""
        return "openai"