# NOTE: removed web-scrape artifacts that preceded this file
# (repository "topics" hint text and "102 lines / 3.1 KiB" page metadata).
# app/semantic_cache.py
|
|
import json

import numpy as np
import redis
from redis.commands.search.field import TagField, TextField, VectorField
from redis.commands.search.index_definition import IndexDefinition, IndexType
from redis.commands.search.query import Query
|
|
|
|
# Shared Redis connection for all cache operations.
# decode_responses=False is required: embedding vectors are stored as raw
# FLOAT32 bytes and must not be UTF-8 decoded on read.
redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=False)

# Maximum cosine *distance* (not similarity) for a lookup to count as a hit;
# smaller is stricter.
SIMILARITY_THRESHOLD = 0.15

VECTOR_DIM = 384  # all-MiniLM-L6-v2 output size

# Name of the RediSearch index over semcache:* hashes.
INDEX_NAME = "semantic_cache_idx"
|
def create_index_if_not_exists():
    """Create the HNSW vector index in Redis on startup.

    Idempotent: if the index already exists this is a no-op. Indexes all
    hashes whose keys start with ``semcache:`` (as written by
    ``set_semantic_cache``).

    Raises:
        redis.exceptions.ConnectionError: if Redis is unreachable — we let
            this propagate so startup fails fast instead of masking it.
    """
    try:
        # FT.INFO raises ResponseError ("Unknown index name") when the
        # index does not exist yet — that is our creation trigger.
        redis_client.ft(INDEX_NAME).info()
        print("Semantic cache index already exists")
    except redis.exceptions.ResponseError:
        schema = [
            TextField("query"),
            # Must be a TAG field: lookups filter with the curly-brace tag
            # syntax `@domain:{...}`, which RediSearch rejects on TEXT fields.
            TagField("domain"),
            VectorField(
                "embedding",
                "HNSW",  # O(log N) approximate nearest-neighbour search
                {
                    "TYPE": "FLOAT32",
                    "DIM": VECTOR_DIM,
                    "DISTANCE_METRIC": "COSINE",
                    # HNSW graph tuning: M = links per node,
                    # EF_CONSTRUCTION = build-time candidate list size.
                    "M": 16,
                    "EF_CONSTRUCTION": 200,
                },
            ),
        ]
        redis_client.ft(INDEX_NAME).create_index(
            schema,
            definition=IndexDefinition(
                prefix=["semcache:"],
                index_type=IndexType.HASH,
            ),
        )
        print("Semantic cache HNSW index created")
|
|
|
|
|
|
def get_semantic_cache(embedding: list, domain: str = None):
    """
    Search Redis HNSW index with optional domain filter.
    """
    vec_bytes = np.array(embedding, dtype=np.float32).tobytes()

    # Restrict the KNN search to a single domain tag when one is given;
    # spaces are normalized to underscores to match what was stored.
    if domain:
        tag = domain.replace(' ', '_')
        query_str = f"(@domain:{{{tag}}})=>[KNN 1 @embedding $vec AS distance]"
    else:
        query_str = "*=>[KNN 1 @embedding $vec AS distance]"

    # Dialect 2 is required for the vector-search query syntax above.
    search_query = Query(query_str)
    search_query = search_query.sort_by("distance")
    search_query = search_query.return_fields("query", "suggestions", "distance", "domain")
    search_query = search_query.dialect(2)

    try:
        results = redis_client.ft(INDEX_NAME).search(
            search_query, query_params={"vec": vec_bytes}
        )
    except Exception as e:
        # Best-effort cache: on any search failure, log and fall through
        # to a miss rather than breaking the caller.
        print(f"Semantic cache search error: {e}")
        return None

    if not results.docs:
        print("Semantic cache MISS (empty index)")
        return None

    best = results.docs[0]
    distance = float(best.distance)

    # Cosine distance below the threshold counts as "semantically the same".
    if distance < SIMILARITY_THRESHOLD:
        print(f"Semantic cache HIT (distance: {distance:.4f}, domain: {domain})")
        return json.loads(best.suggestions)

    print(f"Semantic cache MISS (distance: {distance:.4f})")
    return None
|
|
|
|
|
|
def set_semantic_cache(query: str, embedding: list, suggestions: list, domain: str = None, ttl: int = 86400):
    """Store a query's suggestions with domain metadata for filtering.

    Args:
        query: The raw query text; also used (prefixed) as the Redis key,
            so identical queries overwrite each other.
        embedding: Query embedding; stored as raw FLOAT32 bytes. Assumed to
            be VECTOR_DIM floats — TODO confirm callers guarantee this.
        suggestions: JSON-serializable payload returned on a cache hit.
        domain: Optional domain tag; spaces become underscores so the value
            round-trips through the tag filter in get_semantic_cache.
        ttl: Entry lifetime in seconds (default 1 day).
    """
    key = f"semcache:{query}"
    vector = np.array(embedding, dtype=np.float32).tobytes()

    # HSET and EXPIRE go through one transactional pipeline (MULTI/EXEC):
    # a single round trip, and no window where a crash could leave an
    # entry without a TTL that would then never expire.
    pipe = redis_client.pipeline()
    pipe.hset(key, mapping={
        "query": query,
        "embedding": vector,
        "suggestions": json.dumps(suggestions),
        "domain": domain.replace(" ", "_") if domain else "unknown",
    })
    pipe.expire(key, ttl)
    pipe.execute()
    print(f"Semantic cached: {query} (domain: {domain})")