You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

255 lines
9.0 KiB

3 days ago
# app/main.py
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import create_engine, text
from sentence_transformers import SentenceTransformer
from pydantic import BaseModel
from dotenv import load_dotenv
from app.bootstrap import bootstrap_domain
from app.db_schema import ensure_schema
from app.normalizer import normalize_query
from app.embedding_cache import get_or_encode
from app.semantic_cache import get_semantic_cache, set_semantic_cache, create_index_if_not_exists
import threading
import requests as http_requests
import os
import re
import time
from app.vertex_client import get_access_token
load_dotenv()
app = FastAPI()

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers (the CORS spec forbids wildcard origins with
# credentials) — pin explicit origins before production. Left as-is for dev.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configuration: read from the environment (.env is loaded above via
# load_dotenv()) and fall back to the previous hard-coded values so local
# development keeps working unchanged. Previously these were hard-coded,
# which both ignored .env and committed DB credentials to source.
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://postgres:postgres@localhost:5432/decision_engine",
)
engine = create_engine(DATABASE_URL)

# Shared sentence-embedding model used by /suggest (loaded once at import).
model = SentenceTransformer("all-MiniLM-L6-v2")

GOOGLE_PROJECT_ID = os.getenv("GOOGLE_PROJECT_ID", "sylvan-deck-387207")
LOCATION = os.getenv("GOOGLE_LOCATION", "us-central1")

# ── Startup ──────────────────────────────────────────
ensure_schema(engine)               # create tables if missing
create_index_if_not_exists()        # semantic-cache vector index
print("✅ System ready")
# ── Async cache writer ────────────────────────────────
def write_cache_async(query: str, embedding: list, suggestions: list, domain: str = None):
    """Persist a semantic-cache entry on a background daemon thread.

    Fire-and-forget: the request path is never blocked, and any failure in
    the cache backend is logged and swallowed rather than propagated.
    """
    def _task():
        try:
            set_semantic_cache(query, embedding, suggestions, domain=domain)
        except Exception as e:
            print(f"⚠️ Cache write failed: {e}")

    threading.Thread(target=_task, daemon=True).start()
# ── Gibberish detector ────────────────────────────────
def is_gibberish(text: str) -> bool:
    """Heuristically decide whether *text* looks like keyboard mashing.

    A word of 4+ characters counts as garbled when it has a run of 6+
    non-vowel characters, almost no vowels for its length, or a
    letters-and-digits mash. The text is gibberish when more than half
    of its words are garbled. Empty input is treated as gibberish.
    """
    words = text.lower().split()
    if not words:
        return True

    def _looks_garbled(word: str) -> bool:
        # 6+ consecutive non-vowel characters
        if re.search(r'[^aeiou]{6,}', word):
            return True
        # long word whose vowel ratio is under 10%
        if len(word) > 5 and len(re.findall(r'[aeiou]', word)) / len(word) < 0.1:
            return True
        # letters fused with long digit runs (e.g. "a1234b")
        return bool(re.search(r'[a-z]\d{3,}[a-z]|[0-9]{4,}[a-z]', word))

    # Words of up to 3 characters are always allowed: ev, ai, bmw, suv
    garbled = sum(1 for w in words if len(w) > 3 and _looks_garbled(w))
    return garbled > len(words) / 2
# ── /suggest endpoint ─────────────────────────────────
@app.get("/suggest")
def suggest(query: str, offset: int = 0, limit: int = 15):
    """Return ranked attribute suggestions for a free-text query.

    Flow: reject trivial/gibberish input -> normalize -> embed ->
    semantic-cache lookup (full queries, first page only) -> nearest
    domain via pgvector distance -> rank that domain's attributes by
    embedding similarity, paginated with offset/limit.
    """
    if len(query.strip()) < 2:
        return {"suggestions": [], "cache": "skip"}
    # Block gibberish server-side
    if is_gibberish(query.strip()):
        return {"suggestions": [], "cache": "gibberish"}
    normalized = normalize_query(query.strip())
    # (fixed: the separator between the two values was missing in the log)
    print(f"📝 Normalized: '{query}' → '{normalized}'")
    embedding = get_or_encode(normalized, model)
    word_count = len(query.strip().split())
    # Use semantic cache only for full queries (3+ words) on first page
    if word_count >= 3 and offset == 0:
        # NOTE(review): the normalized query is passed as `domain` here —
        # confirm that is the intended cache-scoping key.
        cached = get_semantic_cache(embedding, domain=normalized)
        if cached:
            print(f"✅ Semantic cache HIT")
            return {"suggestions": cached[offset:offset + limit], "cache": "semantic_hit"}
    # pgvector accepts the embedding in its string-literal form
    emb_param = str(embedding)
    with engine.begin() as conn:
        domain_row = conn.execute(text("""
            SELECT id, name,
                   embedding <-> CAST(:emb AS vector) AS distance
            FROM domains
            ORDER BY distance
            LIMIT 1
        """), {"emb": emb_param}).fetchone()
        if domain_row is None or domain_row.distance > 0.8:
            # No domain close enough: bootstrap one for this query, then
            # re-select the nearest domain (the fresh row should now win).
            if not is_gibberish(query):
                try:
                    bootstrap_domain(query)
                    print(f"✅ Bootstrapped: {query}")
                except Exception as e:
                    # Best-effort: fall through and retry the lookup anyway
                    print(f"⚠️ Bootstrap failed: {e}")
                domain_row = conn.execute(text("""
                    SELECT id, name
                    FROM domains
                    ORDER BY embedding <-> CAST(:emb AS vector)
                    LIMIT 1
                """), {"emb": emb_param}).fetchone()
                if domain_row is None:
                    return {"suggestions": [], "cache": "no_domain"}
                # ✅ continues below to fetch attributes
            else:
                return {"suggestions": [], "cache": "no_domain"}
        # Rank the chosen domain's attributes by similarity (score = 1 - distance)
        results = conn.execute(text("""
            SELECT a.name,
                   1 - (a.embedding <-> CAST(:emb AS vector)) AS score
            FROM attributes a
            JOIN dimension_groups g ON a.group_id = g.id
            WHERE g.domain_id = :domain_id
            ORDER BY score DESC
            LIMIT :limit OFFSET :offset
        """), {"emb": emb_param, "domain_id": domain_row.id,
               "limit": limit, "offset": offset})
        suggestions = [r[0] for r in results]
    domain_name = domain_row.name
    # Deduplicate case-insensitively, preserving rank order (within this page)
    seen = set()
    ranked = []
    for name in suggestions:
        if name.lower() not in seen:
            seen.add(name.lower())
            ranked.append(name)
    # Cache only full queries on first page
    if word_count >= 3 and offset == 0:
        write_cache_async(normalized, embedding, ranked, domain=domain_name)
    return {"suggestions": ranked, "cache": "miss", "domain": domain_name}
# ── /generate endpoint ────────────────────────────────
class GenerateRequest(BaseModel):
    """Request body for POST /generate."""
    query: str  # the user's free-text question
    selected_attributes: list[str]  # evaluation-criteria chips chosen by the user
    chat_history: list[dict] = []  # prior turns; each dict carries 'query', 'chips', 'answer'
def _format_history(chat_history: list[dict]) -> str:
    """Render prior chat turns as a prompt preamble; '' when there is none.

    Missing keys in a turn are tolerated (rendered as empty strings) —
    previously h['query'] / h['answer'] raised KeyError on malformed turns
    while 'chips' was already accessed defensively.
    """
    if not chat_history:
        return ""
    turns = "\n".join(
        f"User: {h.get('query', '')}\n"
        f"Criteria: {', '.join(h.get('chips', []))}\n"
        f"Answer: {h.get('answer', '')}"
        for h in chat_history
    )
    return f"Previous conversation:\n{turns}\n\n"


@app.post("/generate")
def generate(request: GenerateRequest):
    """Answer the user's query with Vertex AI Gemini, analysed per criterion.

    Builds a structured prompt from the query, selected criteria and prior
    turns, then calls generateContent with up to 3 attempts: exponential
    backoff on HTTP 429, a 2s pause and retry on timeout. Always returns
    a dict with an 'answer' key (error text on failure, never an exception).
    """
    if not request.query.strip():
        return {"answer": "Please enter a query."}
    history_text = _format_history(request.chat_history)
    attributes = ", ".join(request.selected_attributes) if request.selected_attributes else "general evaluation"
    # NOTE: the triple-quoted prompt is intentionally unindented — its
    # contents are sent to the model verbatim.
    prompt = f"""{history_text}USER QUESTION: "{request.query}"
EVALUATION CRITERIA SELECTED: {attributes}
You are an expert advisor. The user has specifically asked about "{request.query}".
Your job is to answer the user's question "{request.query}" and analyze it through each of the selected criteria.
Answer the question "{request.query}" directly first in 2-3 sentences.
Then for each criterion in [{attributes}], explain how it applies specifically to "{request.query}".
Format your response as:
## About: {request.query}
[Direct answer to the user's question in 2-3 sentences]
---
[For each selected criterion:]
**[Criterion Name]**
- How this applies to "{request.query}"
- Specific facts, numbers, or data
- Recommendation
---
## Bottom Line
[2-3 sentence summary answering: should the user go with "{request.query}"? What should they prioritize?]
STRICT RULES:
- Every sentence must be about "{request.query}" specifically
- Never give generic advice not related to "{request.query}"
- If unsure about a fact, say "verify on official website"
- Use real numbers where confident
- Total response under 400 words"""
    url = f"https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{GOOGLE_PROJECT_ID}/locations/{LOCATION}/publishers/google/models/gemini-2.5-flash-lite:generateContent"
    print(f"🔄 Calling Vertex AI for: {request.query}")
    for attempt in range(3):
        try:
            res = http_requests.post(
                url,
                headers={
                    "Authorization": f"Bearer {get_access_token()}",
                    "Content-Type": "application/json"
                },
                json={
                    "contents": [{"role": "user", "parts": [{"text": prompt}]}],
                    "generationConfig": {"temperature": 0.3, "maxOutputTokens": 800}
                },
                timeout=30
            )
            if res.status_code == 429:
                # Rate limited: back off 1s, 2s, 4s and retry
                wait = 2 ** attempt
                print(f"⏳ Rate limited, waiting {wait}s...")
                time.sleep(wait)
                continue
            print(f"✅ Vertex response: {res.status_code}")
            res.raise_for_status()
            answer = res.json()["candidates"][0]["content"]["parts"][0]["text"]
            return {"answer": answer}
        except http_requests.exceptions.Timeout:
            print(f"⏰ Timeout on attempt {attempt + 1}")
            if attempt == 2:
                return {"answer": "Request timed out. Please try again."}
            time.sleep(2)
        except Exception as e:
            # Boundary handler: HTTP errors, unexpected response shapes, etc.
            # are reported to the client as a friendly message, not a 500.
            print(f"❌ Error: {e}")
            return {"answer": "Error getting response. Please try again."}
    return {"answer": "Could not get response after 3 attempts. Please try again."}