Source code for redisvl.extensions.router.schema
import warnings
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from typing_extensions import Annotated
from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
from redisvl.schema import IndexSchema
[docs]
class Route(BaseModel):
    """Model representing a routing path with associated metadata and thresholds.

    A route has a non-blank ``name``, at least one non-blank reference phrase,
    optional free-form ``metadata``, and a ``distance_threshold`` constrained
    to the half-open interval (0, 2].
    """

    name: str
    """The name of the route."""
    references: List[str]
    """List of reference phrases for the route."""
    # default_factory builds a fresh dict per instance instead of relying on a
    # shared mutable literal as the default value.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    """Metadata associated with the route."""
    # Bounds (0, 2] match the range of cosine distance, the metric used by the
    # router's index schema; strict=True rejects values such as "0.5" strings.
    distance_threshold: Annotated[float, Field(strict=True, gt=0, le=2)] = 0.5
    """Distance threshold for matching the route."""

    @field_validator("name")
    @classmethod
    def name_must_not_be_empty(cls, v: str) -> str:
        """Reject route names that are empty or contain only whitespace.

        Args:
            v: The candidate route name (already validated as ``str``).

        Returns:
            The name, unchanged.

        Raises:
            ValueError: If the name is empty or whitespace-only.
        """
        # "".strip() is also falsy, so this single check covers the empty case.
        if not v.strip():
            raise ValueError("Route name must not be empty")
        return v

    @field_validator("references")
    @classmethod
    def references_must_not_be_empty(cls, v: List[str]) -> List[str]:
        """Require at least one reference and forbid blank reference strings.

        Args:
            v: The candidate list of reference phrases.

        Returns:
            The list, unchanged.

        Raises:
            ValueError: If the list is empty, or any entry is empty or
                whitespace-only.
        """
        if not v:
            raise ValueError("References must not be empty")
        if any(not ref.strip() for ref in v):
            raise ValueError("All references must be non-empty strings")
        return v
[docs]
class RouteMatch(BaseModel):
    """Outcome of routing a statement: which route matched and how closely.

    Both fields are optional and default to ``None``.
    """

    name: Optional[str] = Field(default=None)
    """Name of the route that was matched, if any."""
    distance: Optional[float] = None
    """Vector distance separating the statement from the matched route."""
[docs]
class DistanceAggregationMethod(Enum):
    """Ways of combining several vector distances into a single score.

    Each member's value is the same lowercase string as its name.
    """

    avg = "avg"
    """Take the arithmetic mean of the vector distances."""
    min = "min"
    """Take the smallest of the vector distances."""
    sum = "sum"
    """Add all of the vector distances together."""
[docs]
class RoutingConfig(BaseModel):
    """Configuration for routing behavior."""

    max_k: Annotated[int, Field(strict=True, gt=0)] = 1
    """The maximum number of top matches to return."""
    aggregation_method: DistanceAggregationMethod = Field(
        default=DistanceAggregationMethod.avg
    )
    """Aggregation method to use to classify queries."""

    # Unknown keys are silently dropped rather than rejected.
    model_config = ConfigDict(extra="ignore")

    @model_validator(mode="before")
    @classmethod
    def remove_distance_threshold(cls, values: Any) -> Any:
        """Drop the deprecated ``distance_threshold`` key, emitting a warning.

        Args:
            values: Raw input to the model. For a "before" model validator this
                is usually a dict, but it is not guaranteed to be one (it may
                be a model instance or another object).

        Returns:
            The input with ``distance_threshold`` removed when the input is a
            dict containing it; otherwise the input unchanged.

        Warns:
            DeprecationWarning: If ``distance_threshold`` is present.
        """
        # Only dict input can carry the legacy key; leave anything else to
        # pydantic's normal validation.
        if isinstance(values, dict) and "distance_threshold" in values:
            warnings.warn(
                "The 'distance_threshold' field is deprecated and will be ignored. Set distance_threshold per Route.",
                DeprecationWarning,
                stacklevel=2,
            )
            # Build a new dict instead of popping so that a mapping passed to
            # e.g. model_validate() is not mutated behind the caller's back.
            values = {
                key: value
                for key, value in values.items()
                if key != "distance_threshold"
            }
        return values
class SemanticRouterIndexSchema(IndexSchema):
    """Customized index schema for SemanticRouter."""

    @classmethod
    def from_params(
        cls, name: str, vector_dims: int, dtype: str
    ) -> "SemanticRouterIndexSchema":
        """Create an index schema based on router name and vector dimensions.

        The resulting schema uses ``name`` as both the index name and the key
        prefix. It declares tag fields ``reference_id`` and ``route_name``, a
        text field ``reference``, and a vector field named
        ``ROUTE_VECTOR_FIELD_NAME`` that uses the ``flat`` algorithm and the
        ``cosine`` distance metric.

        Args:
            name (str): The name of the index; also used as the key prefix.
            vector_dims (int): The dimensions of the vectors.
            dtype (str): Datatype of the stored vectors, passed through as the
                vector field's ``datatype`` attribute.

        Returns:
            SemanticRouterIndexSchema: The constructed index schema.
        """
        return cls(
            index={"name": name, "prefix": name},  # type: ignore
            fields=[  # type: ignore
                {"name": "reference_id", "type": "tag"},
                {"name": "route_name", "type": "tag"},
                {"name": "reference", "type": "text"},
                {
                    "name": ROUTE_VECTOR_FIELD_NAME,
                    "type": "vector",
                    "attrs": {
                        "algorithm": "flat",
                        "dims": vector_dims,
                        "distance_metric": "cosine",
                        "datatype": dtype,
                    },
                },
            ],
        )