Source code for redisvl.extensions.cache.embeddings.embeddings

"""Embeddings cache implementation for RedisVL."""

from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from redis import Redis
from redis.asyncio import Redis as AsyncRedis

from redisvl.extensions.cache.base import BaseCache
from redisvl.extensions.cache.embeddings.schema import CacheEntry
from redisvl.redis.utils import convert_bytes, hashify


class EmbeddingsCache(BaseCache):
    """Embeddings Cache for storing embedding vectors with exact key matching."""

    def __init__(
        self,
        name: str = "embedcache",
        ttl: Optional[int] = None,
        redis_client: Optional[Redis] = None,
        redis_url: str = "redis://localhost:6379",
        connection_kwargs: Dict[str, Any] = {},
    ):
        """Initialize an embeddings cache.

        Args:
            name (str): The name of the cache. Defaults to "embedcache".
            ttl (Optional[int]): The time-to-live for cached embeddings. Defaults to None.
            redis_client (Optional[Redis]): Redis client instance. Defaults to None.
            redis_url (str): Redis URL for connection. Defaults to "redis://localhost:6379".
            connection_kwargs (Dict[str, Any]): Redis connection arguments. Defaults to {}.

        Raises:
            ValueError: If the provided TTL is invalid.

        .. code-block:: python

            cache = EmbeddingsCache(
                name="my_embeddings_cache",
                ttl=3600,  # 1 hour
                redis_url="redis://localhost:6379"
            )
        """
        super().__init__(
            name=name,
            ttl=ttl,
            redis_client=redis_client,
            redis_url=redis_url,
            connection_kwargs=connection_kwargs,
        )

    def _make_entry_id(self, text: str, model_name: str) -> str:
        """Generate a deterministic entry ID for the given text and model name.

        Args:
            text (str): The text input that was embedded.
            model_name (str): The name of the embedding model.

        Returns:
            str: A deterministic entry ID based on the text and model name.
        """
        return hashify(f"{text}:{model_name}")

    def _make_cache_key(self, text: str, model_name: str) -> str:
        """Generate a full Redis key for the given text and model name.

        Args:
            text (str): The text input that was embedded.
            model_name (str): The name of the embedding model.

        Returns:
            str: The full Redis key.
        """
        entry_id = self._make_entry_id(text, model_name)
        return self._make_key(entry_id)

    def _prepare_entry_data(
        self,
        text: str,
        model_name: str,
        embedding: List[float],
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, Dict[str, Any]]:
        """Prepare data for storage in Redis.

        Args:
            text (str): The text input that was embedded.
            model_name (str): The name of the embedding model.
            embedding (List[float]): The embedding vector.
            metadata (Optional[Dict[str, Any]]): Optional metadata.

        Returns:
            Tuple[str, Dict[str, Any]]: A tuple of (key, entry_data).
        """
        # Create cache entry with entry_id
        entry_id = self._make_entry_id(text, model_name)
        key = self._make_key(entry_id)
        entry = CacheEntry(
            entry_id=entry_id,
            text=text,
            model_name=model_name,
            embedding=embedding,
            metadata=metadata,
        )
        return key, entry.to_dict()

    def _process_cache_data(
        self, data: Optional[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Process Redis hash data into a cache entry response.

        Args:
            data (Optional[Dict[str, Any]]): Raw Redis hash data.

        Returns:
            Optional[Dict[str, Any]]: Processed cache entry or None if no data.
        """
        if not data:
            return None
        cache_hit = CacheEntry(**convert_bytes(data))
        return cache_hit.model_dump(exclude_none=True)
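
    # A minimal sketch of the key scheme implied by the helpers above, assuming
    # that BaseCache._make_key prefixes entry IDs with the cache name and that
    # hashify returns a deterministic hex digest:
    #
    #   entry_id = hashify("What is ML?:text-embedding-ada-002")
    #   key = f"embedcache:{entry_id}"
    #
    # Identical (text, model_name) pairs therefore always resolve to the same
    # Redis key, which is what gives the cache its exact-match semantics.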
    def get(
        self,
        text: str,
        model_name: str,
    ) -> Optional[Dict[str, Any]]:
        """Get embedding by text and model name.

        Retrieves a cached embedding for the given text and model name.
        If found, refreshes the TTL of the entry.

        Args:
            text (str): The text input that was embedded.
            model_name (str): The name of the embedding model.

        Returns:
            Optional[Dict[str, Any]]: Embedding cache entry or None if not found.

        .. code-block:: python

            embedding_data = cache.get(
                text="What is machine learning?",
                model_name="text-embedding-ada-002"
            )
        """
        key = self._make_cache_key(text, model_name)
        return self.get_by_key(key)

    def get_by_key(self, key: str) -> Optional[Dict[str, Any]]:
        """Get embedding by its full Redis key.

        Retrieves a cached embedding for the given Redis key.
        If found, refreshes the TTL of the entry.

        Args:
            key (str): The full Redis key for the embedding.

        Returns:
            Optional[Dict[str, Any]]: Embedding cache entry or None if not found.

        .. code-block:: python

            embedding_data = cache.get_by_key("embedcache:1234567890abcdef")
        """
        client = self._get_redis_client()

        # Get all fields
        data = client.hgetall(key)

        # Refresh TTL if data exists
        if data:
            self.expire(key)

        return self._process_cache_data(data)

    def mget_by_keys(self, keys: List[str]) -> List[Optional[Dict[str, Any]]]:
        """Get multiple embeddings by their Redis keys.

        Efficiently retrieves multiple cached embeddings in a single network roundtrip.
        If found, refreshes the TTL of each entry.

        Args:
            keys (List[str]): List of Redis keys to retrieve.

        Returns:
            List[Optional[Dict[str, Any]]]: List of embedding cache entries or None
                for keys not found. The order matches the input keys order.

        .. code-block:: python

            # Get multiple embeddings
            embedding_data = cache.mget_by_keys([
                "embedcache:key1",
                "embedcache:key2"
            ])
        """
        if not keys:
            return []

        client = self._get_redis_client()

        with client.pipeline(transaction=False) as pipeline:
            # Queue all hgetall operations
            for key in keys:
                pipeline.hgetall(key)
            results = pipeline.execute()

        # Process results
        processed_results = []
        for i, result in enumerate(results):
            if result:
                # If cache hit, refresh TTL separately
                self.expire(keys[i])
            processed_results.append(self._process_cache_data(result))

        return processed_results

    def mget(
        self, texts: List[str], model_name: str
    ) -> List[Optional[Dict[str, Any]]]:
        """Get multiple embeddings by their texts and model name.

        Efficiently retrieves multiple cached embeddings in a single operation.
        If found, refreshes the TTL of each entry.

        Args:
            texts (List[str]): List of text inputs that were embedded.
            model_name (str): The name of the embedding model.

        Returns:
            List[Optional[Dict[str, Any]]]: List of embedding cache entries or None
                for texts not found.

        .. code-block:: python

            # Get multiple embeddings
            embedding_data = cache.mget(
                texts=["What is machine learning?", "What is deep learning?"],
                model_name="text-embedding-ada-002"
            )
        """
        if not texts:
            return []

        # Generate keys for each text
        keys = [self._make_cache_key(text, model_name) for text in texts]

        # Use the key-based batch operation
        return self.mget_by_keys(keys)

    def set(
        self,
        text: str,
        model_name: str,
        embedding: List[float],
        metadata: Optional[Dict[str, Any]] = None,
        ttl: Optional[int] = None,
    ) -> str:
        """Store an embedding with its text and model name.

        Args:
            text (str): The text input that was embedded.
            model_name (str): The name of the embedding model.
            embedding (List[float]): The embedding vector to store.
            metadata (Optional[Dict[str, Any]]): Optional metadata to store with the embedding.
            ttl (Optional[int]): Optional TTL override for this specific entry.

        Returns:
            str: The Redis key where the embedding was stored.

        .. code-block:: python

            key = cache.set(
                text="What is machine learning?",
                model_name="text-embedding-ada-002",
                embedding=[0.1, 0.2, 0.3, ...],
                metadata={"source": "user_query"}
            )
        """
        # Prepare data
        key, cache_entry = self._prepare_entry_data(
            text, model_name, embedding, metadata
        )

        # Store in Redis
        client = self._get_redis_client()
        client.hset(name=key, mapping=cache_entry)  # type: ignore

        # Set TTL if specified
        self.expire(key, ttl)

        return key

    def mset(
        self,
        items: List[Dict[str, Any]],
        ttl: Optional[int] = None,
    ) -> List[str]:
        """Store multiple embeddings in a batch operation.

        Each item in the input list should be a dictionary with the following fields:
        - 'text': The text input that was embedded
        - 'model_name': The name of the embedding model
        - 'embedding': The embedding vector
        - 'metadata': Optional metadata to store with the embedding

        Args:
            items: List of dictionaries, each containing text, model_name,
                embedding, and optional metadata.
            ttl: Optional TTL override for these entries.

        Returns:
            List[str]: List of Redis keys where the embeddings were stored.

        .. code-block:: python

            # Store multiple embeddings
            keys = cache.mset([
                {
                    "text": "What is ML?",
                    "model_name": "text-embedding-ada-002",
                    "embedding": [0.1, 0.2, 0.3],
                    "metadata": {"source": "user"}
                },
                {
                    "text": "What is AI?",
                    "model_name": "text-embedding-ada-002",
                    "embedding": [0.4, 0.5, 0.6],
                    "metadata": {"source": "docs"}
                }
            ])
        """
        if not items:
            return []

        client = self._get_redis_client()
        keys = []

        with client.pipeline(transaction=False) as pipeline:
            # Process all entries
            for item in items:
                # Prepare and store
                key, cache_entry = self._prepare_entry_data(**item)
                keys.append(key)
                pipeline.hset(name=key, mapping=cache_entry)  # type: ignore
            pipeline.execute()

        # Set TTLs
        for key in keys:
            self.expire(key, ttl)

        return keys

    def exists(self, text: str, model_name: str) -> bool:
        """Check if an embedding exists for the given text and model.

        Args:
            text (str): The text input that was embedded.
            model_name (str): The name of the embedding model.

        Returns:
            bool: True if the embedding exists in the cache, False otherwise.

        .. code-block:: python

            if cache.exists("What is machine learning?", "text-embedding-ada-002"):
                print("Embedding is in cache")
        """
        client = self._get_redis_client()
        key = self._make_cache_key(text, model_name)
        return bool(client.exists(key))

    def exists_by_key(self, key: str) -> bool:
        """Check if an embedding exists for the given Redis key.

        Args:
            key (str): The full Redis key for the embedding.

        Returns:
            bool: True if the embedding exists in the cache, False otherwise.

        .. code-block:: python

            if cache.exists_by_key("embedcache:1234567890abcdef"):
                print("Embedding is in cache")
        """
        client = self._get_redis_client()
        return bool(client.exists(key))

    def mexists_by_keys(self, keys: List[str]) -> List[bool]:
        """Check if multiple embeddings exist by their Redis keys.

        Efficiently checks existence of multiple keys in a single operation.

        Args:
            keys (List[str]): List of Redis keys to check.

        Returns:
            List[bool]: List of boolean values indicating whether each key exists.
                The order matches the input keys order.

        .. code-block:: python

            # Check if multiple keys exist
            exists_results = cache.mexists_by_keys(["embedcache:key1", "embedcache:key2"])
        """
        if not keys:
            return []

        client = self._get_redis_client()

        with client.pipeline(transaction=False) as pipeline:
            # Queue all exists operations
            for key in keys:
                pipeline.exists(key)
            results = pipeline.execute()

        # Convert to boolean values
        return [bool(result) for result in results]

    def mexists(self, texts: List[str], model_name: str) -> List[bool]:
        """Check if multiple embeddings exist by their texts and model name.

        Efficiently checks existence of multiple embeddings in a single operation.

        Args:
            texts (List[str]): List of text inputs that were embedded.
            model_name (str): The name of the embedding model.

        Returns:
            List[bool]: List of boolean values indicating whether each embedding exists.

        .. code-block:: python

            # Check if multiple embeddings exist
            exists_results = cache.mexists(
                texts=["What is machine learning?", "What is deep learning?"],
                model_name="text-embedding-ada-002"
            )
        """
        if not texts:
            return []

        # Generate keys for each text
        keys = [self._make_cache_key(text, model_name) for text in texts]

        # Use the key-based batch operation
        return self.mexists_by_keys(keys)

    def drop(self, text: str, model_name: str) -> None:
        """Remove an embedding from the cache.

        Args:
            text (str): The text input that was embedded.
            model_name (str): The name of the embedding model.

        .. code-block:: python

            cache.drop(
                text="What is machine learning?",
                model_name="text-embedding-ada-002"
            )
        """
        key = self._make_cache_key(text, model_name)
        self.drop_by_key(key)

    def drop_by_key(self, key: str) -> None:
        """Remove an embedding from the cache by its Redis key.

        Args:
            key (str): The full Redis key for the embedding.

        .. code-block:: python

            cache.drop_by_key("embedcache:1234567890abcdef")
        """
        client = self._get_redis_client()
        client.delete(key)

    def mdrop_by_keys(self, keys: List[str]) -> None:
        """Remove multiple embeddings from the cache by their Redis keys.

        Efficiently removes multiple embeddings in a single operation.

        Args:
            keys (List[str]): List of Redis keys to remove.

        .. code-block:: python

            # Remove multiple embeddings
            cache.mdrop_by_keys(["embedcache:key1", "embedcache:key2"])
        """
        if not keys:
            return

        client = self._get_redis_client()

        with client.pipeline(transaction=False) as pipeline:
            for key in keys:
                pipeline.delete(key)
            pipeline.execute()

    def mdrop(self, texts: List[str], model_name: str) -> None:
        """Remove multiple embeddings from the cache by their texts and model name.

        Efficiently removes multiple embeddings in a single operation.

        Args:
            texts (List[str]): List of text inputs that were embedded.
            model_name (str): The name of the embedding model.

        .. code-block:: python

            # Remove multiple embeddings
            cache.mdrop(
                texts=["What is machine learning?", "What is deep learning?"],
                model_name="text-embedding-ada-002"
            )
        """
        if not texts:
            return

        # Generate keys for each text
        keys = [self._make_cache_key(text, model_name) for text in texts]

        # Use the key-based batch operation
        self.mdrop_by_keys(keys)
    async def aget(
        self,
        text: str,
        model_name: str,
    ) -> Optional[Dict[str, Any]]:
        """Async get embedding by text and model name.

        Asynchronously retrieves a cached embedding for the given text and model name.
        If found, refreshes the TTL of the entry.

        Args:
            text (str): The text input that was embedded.
            model_name (str): The name of the embedding model.

        Returns:
            Optional[Dict[str, Any]]: Embedding cache entry or None if not found.

        .. code-block:: python

            embedding_data = await cache.aget(
                text="What is machine learning?",
                model_name="text-embedding-ada-002"
            )
        """
        key = self._make_cache_key(text, model_name)
        return await self.aget_by_key(key)

    async def aget_by_key(self, key: str) -> Optional[Dict[str, Any]]:
        """Async get embedding by its full Redis key.

        Asynchronously retrieves a cached embedding for the given Redis key.
        If found, refreshes the TTL of the entry.

        Args:
            key (str): The full Redis key for the embedding.

        Returns:
            Optional[Dict[str, Any]]: Embedding cache entry or None if not found.

        .. code-block:: python

            embedding_data = await cache.aget_by_key("embedcache:1234567890abcdef")
        """
        client = await self._get_async_redis_client()

        # Get all fields
        data = await client.hgetall(key)

        # Refresh TTL if data exists
        if data:
            await self.aexpire(key)

        return self._process_cache_data(data)

    async def amget_by_keys(self, keys: List[str]) -> List[Optional[Dict[str, Any]]]:
        """Async get multiple embeddings by their Redis keys.

        Asynchronously retrieves multiple cached embeddings in a single network roundtrip.
        If found, refreshes the TTL of each entry.

        Args:
            keys (List[str]): List of Redis keys to retrieve.

        Returns:
            List[Optional[Dict[str, Any]]]: List of embedding cache entries or None
                for keys not found. The order matches the input keys order.

        .. code-block:: python

            # Get multiple embeddings asynchronously
            embedding_data = await cache.amget_by_keys([
                "embedcache:key1",
                "embedcache:key2"
            ])
        """
        if not keys:
            return []

        client = await self._get_async_redis_client()

        # Use pipeline only for retrieval
        async with client.pipeline(transaction=False) as pipeline:
            # Queue all hgetall operations
            for key in keys:
                await pipeline.hgetall(key)
            results = await pipeline.execute()

        # Process results and refresh TTLs separately
        processed_results = []
        for i, result in enumerate(results):
            if result:
                # If cache hit, refresh TTL
                await self.aexpire(keys[i])
            processed_results.append(self._process_cache_data(result))

        return processed_results

    async def amget(
        self, texts: List[str], model_name: str
    ) -> List[Optional[Dict[str, Any]]]:
        """Async get multiple embeddings by their texts and model name.

        Asynchronously retrieves multiple cached embeddings in a single operation.
        If found, refreshes the TTL of each entry.

        Args:
            texts (List[str]): List of text inputs that were embedded.
            model_name (str): The name of the embedding model.

        Returns:
            List[Optional[Dict[str, Any]]]: List of embedding cache entries or None
                for texts not found.

        .. code-block:: python

            # Get multiple embeddings asynchronously
            embedding_data = await cache.amget(
                texts=["What is machine learning?", "What is deep learning?"],
                model_name="text-embedding-ada-002"
            )
        """
        if not texts:
            return []

        # Generate keys for each text
        keys = [self._make_cache_key(text, model_name) for text in texts]

        # Use the key-based batch operation
        return await self.amget_by_keys(keys)

    async def aset(
        self,
        text: str,
        model_name: str,
        embedding: List[float],
        metadata: Optional[Dict[str, Any]] = None,
        ttl: Optional[int] = None,
    ) -> str:
        """Async store an embedding with its text and model name.

        Asynchronously stores an embedding with its text and model name.

        Args:
            text (str): The text input that was embedded.
            model_name (str): The name of the embedding model.
            embedding (List[float]): The embedding vector to store.
            metadata (Optional[Dict[str, Any]]): Optional metadata to store with the embedding.
            ttl (Optional[int]): Optional TTL override for this specific entry.

        Returns:
            str: The Redis key where the embedding was stored.

        .. code-block:: python

            key = await cache.aset(
                text="What is machine learning?",
                model_name="text-embedding-ada-002",
                embedding=[0.1, 0.2, 0.3, ...],
                metadata={"source": "user_query"}
            )
        """
        # Prepare data
        key, cache_entry = self._prepare_entry_data(
            text, model_name, embedding, metadata
        )

        # Store in Redis
        client = await self._get_async_redis_client()
        await client.hset(name=key, mapping=cache_entry)  # type: ignore

        # Set TTL if specified
        await self.aexpire(key, ttl)

        return key

    async def amset(
        self,
        items: List[Dict[str, Any]],
        ttl: Optional[int] = None,
    ) -> List[str]:
        """Async store multiple embeddings in a batch operation.

        Each item in the input list should be a dictionary with the following fields:
        - 'text': The text input that was embedded
        - 'model_name': The name of the embedding model
        - 'embedding': The embedding vector
        - 'metadata': Optional metadata to store with the embedding

        Args:
            items: List of dictionaries, each containing text, model_name,
                embedding, and optional metadata.
            ttl: Optional TTL override for these entries.

        Returns:
            List[str]: List of Redis keys where the embeddings were stored.

        .. code-block:: python

            # Store multiple embeddings asynchronously
            keys = await cache.amset([
                {
                    "text": "What is ML?",
                    "model_name": "text-embedding-ada-002",
                    "embedding": [0.1, 0.2, 0.3],
                    "metadata": {"source": "user"}
                },
                {
                    "text": "What is AI?",
                    "model_name": "text-embedding-ada-002",
                    "embedding": [0.4, 0.5, 0.6],
                    "metadata": {"source": "docs"}
                }
            ])
        """
        if not items:
            return []

        client = await self._get_async_redis_client()
        keys = []

        async with client.pipeline(transaction=False) as pipeline:
            # Process all entries
            for item in items:
                # Prepare and store
                key, cache_entry = self._prepare_entry_data(**item)
                keys.append(key)
                await pipeline.hset(name=key, mapping=cache_entry)  # type: ignore
            await pipeline.execute()

        # Set TTLs
        for key in keys:
            await self.aexpire(key, ttl)

        return keys

    async def amexists_by_keys(self, keys: List[str]) -> List[bool]:
        """Async check if multiple embeddings exist by their Redis keys.

        Asynchronously checks existence of multiple keys in a single operation.

        Args:
            keys (List[str]): List of Redis keys to check.

        Returns:
            List[bool]: List of boolean values indicating whether each key exists.
                The order matches the input keys order.

        .. code-block:: python

            # Check if multiple keys exist asynchronously
            exists_results = await cache.amexists_by_keys(["embedcache:key1", "embedcache:key2"])
        """
        if not keys:
            return []

        client = await self._get_async_redis_client()

        async with client.pipeline(transaction=False) as pipeline:
            # Queue all exists operations
            for key in keys:
                await pipeline.exists(key)
            results = await pipeline.execute()

        # Convert to boolean values
        return [bool(result) for result in results]

    async def amexists(self, texts: List[str], model_name: str) -> List[bool]:
        """Async check if multiple embeddings exist by their texts and model name.

        Asynchronously checks existence of multiple embeddings in a single operation.

        Args:
            texts (List[str]): List of text inputs that were embedded.
            model_name (str): The name of the embedding model.

        Returns:
            List[bool]: List of boolean values indicating whether each embedding exists.

        .. code-block:: python

            # Check if multiple embeddings exist asynchronously
            exists_results = await cache.amexists(
                texts=["What is machine learning?", "What is deep learning?"],
                model_name="text-embedding-ada-002"
            )
        """
        if not texts:
            return []

        # Generate keys for each text
        keys = [self._make_cache_key(text, model_name) for text in texts]

        # Use the key-based batch operation
        return await self.amexists_by_keys(keys)

    async def amdrop_by_keys(self, keys: List[str]) -> None:
        """Async remove multiple embeddings from the cache by their Redis keys.

        Asynchronously removes multiple embeddings in a single operation.

        Args:
            keys (List[str]): List of Redis keys to remove.

        .. code-block:: python

            # Remove multiple embeddings asynchronously
            await cache.amdrop_by_keys(["embedcache:key1", "embedcache:key2"])
        """
        if not keys:
            return

        client = await self._get_async_redis_client()
        await client.delete(*keys)

    async def amdrop(self, texts: List[str], model_name: str) -> None:
        """Async remove multiple embeddings from the cache by their texts and model name.

        Asynchronously removes multiple embeddings in a single operation.

        Args:
            texts (List[str]): List of text inputs that were embedded.
            model_name (str): The name of the embedding model.

        .. code-block:: python

            # Remove multiple embeddings asynchronously
            await cache.amdrop(
                texts=["What is machine learning?", "What is deep learning?"],
                model_name="text-embedding-ada-002"
            )
        """
        if not texts:
            return

        # Generate keys for each text
        keys = [self._make_cache_key(text, model_name) for text in texts]

        # Use the key-based batch operation
        await self.amdrop_by_keys(keys)

    async def aexists(self, text: str, model_name: str) -> bool:
        """Async check if an embedding exists.

        Asynchronously checks if an embedding exists for the given text and model.

        Args:
            text (str): The text input that was embedded.
            model_name (str): The name of the embedding model.

        Returns:
            bool: True if the embedding exists in the cache, False otherwise.

        .. code-block:: python

            if await cache.aexists("What is machine learning?", "text-embedding-ada-002"):
                print("Embedding is in cache")
        """
        key = self._make_cache_key(text, model_name)
        return await self.aexists_by_key(key)

    async def aexists_by_key(self, key: str) -> bool:
        """Async check if an embedding exists for the given Redis key.

        Asynchronously checks if an embedding exists for the given Redis key.

        Args:
            key (str): The full Redis key for the embedding.

        Returns:
            bool: True if the embedding exists in the cache, False otherwise.

        .. code-block:: python

            if await cache.aexists_by_key("embedcache:1234567890abcdef"):
                print("Embedding is in cache")
        """
        client = await self._get_async_redis_client()
        return bool(await client.exists(key))

    async def adrop(self, text: str, model_name: str) -> None:
        """Async remove an embedding from the cache.

        Asynchronously removes an embedding from the cache.

        Args:
            text (str): The text input that was embedded.
            model_name (str): The name of the embedding model.

        .. code-block:: python

            await cache.adrop(
                text="What is machine learning?",
                model_name="text-embedding-ada-002"
            )
        """
        key = self._make_cache_key(text, model_name)
        await self.adrop_by_key(key)

    async def adrop_by_key(self, key: str) -> None:
        """Async remove an embedding from the cache by its Redis key.

        Asynchronously removes an embedding from the cache by its Redis key.

        Args:
            key (str): The full Redis key for the embedding.

        .. code-block:: python

            await cache.adrop_by_key("embedcache:1234567890abcdef")
        """
        client = await self._get_async_redis_client()
        await client.delete(key)
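

# A minimal usage sketch, assuming a Redis server is reachable at
# redis://localhost:6379 and substituting a tiny placeholder vector for real
# embedding-model output. Every method shown also has an "a"-prefixed async
# counterpart (aget, amset, ...) usable under asyncio.
if __name__ == "__main__":
    cache = EmbeddingsCache(name="embedcache", ttl=3600)

    # Store one embedding and read it back by its derived key
    key = cache.set(
        text="What is machine learning?",
        model_name="text-embedding-ada-002",
        embedding=[0.1, 0.2, 0.3],  # placeholder vector, not real model output
        metadata={"source": "user_query"},
    )
    hit = cache.get_by_key(key)
    assert hit is not None and hit["text"] == "What is machine learning?"

    # Batch store, existence check, and cleanup, each in a single roundtrip
    keys = cache.mset([
        {
            "text": "What is AI?",
            "model_name": "text-embedding-ada-002",
            "embedding": [0.4, 0.5, 0.6],
        }
    ])
    print(cache.mexists_by_keys(keys))  # expected: [True]

    cache.mdrop_by_keys(keys)
    cache.drop_by_key(key)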