You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

102 lines
3.1 KiB

# app/semantic_cache.py
import json

import numpy as np
import redis
from redis.commands.search.field import TagField, TextField, VectorField
from redis.commands.search.index_definition import IndexDefinition, IndexType
from redis.commands.search.query import Query
# Shared Redis connection used by every function in this module.
# decode_responses=False leaves replies undecoded; presumably chosen because
# the "embedding" hash field holds a binary float32 blob -- confirm.
redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=False)
# Upper bound (exclusive) on the KNN distance for a lookup to count as a cache
# hit; the index is built with the COSINE distance metric.
SIMILARITY_THRESHOLD = 0.15
# Length of every stored/query vector; passed as DIM when the index is created.
VECTOR_DIM = 384 # all-MiniLM-L6-v2 output size
# Name of the search index created on startup and queried on every lookup.
INDEX_NAME = "semantic_cache_idx"
def create_index_if_not_exists():
    """Create the HNSW vector index for the semantic cache if it is missing.

    Probes the index with ``info()``; when Redis answers with an error reply
    the index is created over every hash whose key starts with ``semcache:``.
    Connection-level failures are *not* treated as "index missing" and
    propagate to the caller.

    NOTE(review): an index that already exists is left untouched, so an index
    created with an older schema has to be dropped and rebuilt manually before
    a schema change made here takes effect.

    Raises:
        redis.exceptions.RedisError: if Redis is unreachable or rejects the
            index creation.
    """
    index = redis_client.ft(INDEX_NAME)

    try:
        index.info()
    except redis.exceptions.ResponseError:
        # Error reply from the server => the index has not been created yet.
        index_exists = False
    else:
        index_exists = True

    if index_exists:
        print("Semantic cache index already exists")
        return

    schema = [
        TextField("query"),
        # TAG rather than TEXT: lookups filter with `@domain:{...}`, which is
        # tag-query syntax and only matches fields indexed as TAG.
        TagField("domain"),
        VectorField(
            "embedding",
            "HNSW",  # approximate nearest-neighbour graph index
            {
                "TYPE": "FLOAT32",
                "DIM": VECTOR_DIM,
                "DISTANCE_METRIC": "COSINE",
                "M": 16,
                "EF_CONSTRUCTION": 200,
            },
        ),
    ]
    index.create_index(
        schema,
        definition=IndexDefinition(
            prefix=["semcache:"],
            index_type=IndexType.HASH,
        ),
    )
    print("Semantic cache HNSW index created")
def get_semantic_cache(embedding: list, domain: str = None):
    """Return cached suggestions for the nearest stored query, if close enough.

    Runs a 1-nearest-neighbour search on the ``embedding`` field and treats
    the result as a hit only when its distance is below
    ``SIMILARITY_THRESHOLD``.

    Args:
        embedding: Query vector; converted to a float32 byte blob.
        domain: Optional domain to restrict the search to. Spaces are replaced
            with underscores (matching how entries are stored) and any other
            punctuation is backslash-escaped so it cannot break or alter the
            query. The filter uses tag-query syntax (``@domain:{...}``) --
            NOTE(review): this requires ``domain`` to be indexed as a TAG
            field; confirm against the index schema.

    Returns:
        The JSON-decoded ``suggestions`` of the matching entry, or ``None`` on
        a miss, on a search error, or when the stored payload is unreadable.
    """
    query_vector = np.array(embedding, dtype=np.float32).tobytes()

    if domain:
        tag = domain.replace(" ", "_")
        # Escape everything except letters, digits and "_" so that characters
        # such as "-", "." or "}" are matched literally instead of being
        # parsed as query syntax.
        escaped_tag = "".join(
            ch if ch.isalnum() or ch == "_" else f"\\{ch}" for ch in tag
        )
        query_str = (
            f"(@domain:{{{escaped_tag}}})=>[KNN 1 @embedding $vec AS distance]"
        )
    else:
        query_str = "*=>[KNN 1 @embedding $vec AS distance]"

    query = (
        Query(query_str)
        .sort_by("distance")
        .return_fields("query", "suggestions", "distance", "domain")
        .dialect(2)
    )

    # The cache is best-effort: any failure while searching is reported and
    # handled as a miss rather than propagated to the caller.
    try:
        results = redis_client.ft(INDEX_NAME).search(
            query, query_params={"vec": query_vector}
        )
    except Exception as e:
        print(f"Semantic cache search error: {e}")
        return None

    if not results.docs:
        print("Semantic cache MISS (empty index)")
        return None

    top = results.docs[0]
    distance = float(top.distance)
    if distance >= SIMILARITY_THRESHOLD:
        print(f"Semantic cache MISS (distance: {distance:.4f})")
        return None

    # A matching entry whose payload is missing or not valid JSON is useless
    # to the caller, so report it and fall back to a miss.
    try:
        suggestions = json.loads(top.suggestions)
    except (AttributeError, TypeError, ValueError) as e:
        print(f"Semantic cache entry unreadable: {e}")
        return None

    print(f"Semantic cache HIT (distance: {distance:.4f}, domain: {domain})")
    return suggestions
def set_semantic_cache(query: str, embedding: list, suggestions: list, domain: str = None, ttl: int = 86400):
    """Store a query, its embedding and its suggestions as a cache entry.

    The entry is a Redis hash stored under ``semcache:<query>``.

    Args:
        query: Original query text; also used to build the key.
        embedding: Vector for ``query``; stored as a float32 byte blob.
        suggestions: JSON-serialisable payload returned on a later cache hit.
        domain: Optional domain label. Spaces are replaced with underscores;
            a missing/empty domain is stored as ``"unknown"``.
        ttl: Lifetime of the entry in seconds (default: one day).

    Raises:
        redis.exceptions.RedisError: if the write fails.
    """
    key = f"semcache:{query}"
    vector = np.array(embedding, dtype=np.float32).tobytes()
    stored_domain = domain.replace(" ", "_") if domain else "unknown"

    # Queue HSET and EXPIRE in one MULTI/EXEC transaction so an entry can never
    # be left behind without a TTL if the second command is not delivered.
    with redis_client.pipeline(transaction=True) as pipe:
        pipe.hset(
            key,
            mapping={
                "query": query,
                "embedding": vector,
                "suggestions": json.dumps(suggestions),
                "domain": stored_domain,
            },
        )
        pipe.expire(key, ttl)
        pipe.execute()

    print(f"Semantic cached: {query} (domain: {domain})")