You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
255 lines
9.0 KiB
255 lines
9.0 KiB
|
3 days ago
|
# app/main.py
|
||
|
|
from fastapi import FastAPI
|
||
|
|
from fastapi.middleware.cors import CORSMiddleware
|
||
|
|
from sqlalchemy import create_engine, text
|
||
|
|
from sentence_transformers import SentenceTransformer
|
||
|
|
from pydantic import BaseModel
|
||
|
|
from dotenv import load_dotenv
|
||
|
|
from app.bootstrap import bootstrap_domain
|
||
|
|
from app.db_schema import ensure_schema
|
||
|
|
from app.normalizer import normalize_query
|
||
|
|
from app.embedding_cache import get_or_encode
|
||
|
|
from app.semantic_cache import get_semantic_cache, set_semantic_cache, create_index_if_not_exists
|
||
|
|
import threading
|
||
|
|
import requests as http_requests
|
||
|
|
import os
|
||
|
|
import re
|
||
|
|
import time
|
||
|
|
from app.vertex_client import get_access_token
|
||
|
|
|
||
|
|
# Pull in a .env file before reading any configuration.
load_dotenv()

app = FastAPI()

# NOTE(review): wide-open CORS (any origin + credentials) is fine for local
# development but should be tightened before public exposure.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configuration is overridable via the environment (load_dotenv() above pulls
# in a .env file); the previously hard-coded values are kept as defaults so
# existing local setups keep working unchanged.
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://postgres:postgres@localhost:5432/decision_engine",
)
engine = create_engine(DATABASE_URL)
# Sentence embedding model used for queries, domains and attributes.
model = SentenceTransformer("all-MiniLM-L6-v2")

GOOGLE_PROJECT_ID = os.getenv("GOOGLE_PROJECT_ID", "sylvan-deck-387207")
LOCATION = os.getenv("GOOGLE_LOCATION", "us-central1")

# ── Startup ──────────────────────────────────────────
# Create tables / vector index on import so the API is usable immediately.
ensure_schema(engine)
create_index_if_not_exists()
print("✅ System ready")
|
|
|
||
|
|
# ── Async cache writer ────────────────────────────────
|
||
|
|
def write_cache_async(query: str, embedding: list, suggestions: list, domain: str = None):
|
||
|
|
def _write():
|
||
|
|
try:
|
||
|
|
set_semantic_cache(query, embedding, suggestions, domain=domain)
|
||
|
|
except Exception as e:
|
||
|
|
print(f"⚠️ Cache write failed: {e}")
|
||
|
|
threading.Thread(target=_write, daemon=True).start()
|
||
|
|
|
||
|
|
|
||
|
|
# ── Gibberish detector ────────────────────────────────
def is_gibberish(text: str) -> bool:
    """Heuristically decide whether *text* looks like keyboard mashing.

    Empty input counts as gibberish. Otherwise each whitespace-separated
    token is tested against three lexical red flags, and the whole string is
    flagged only when more than half of its tokens trip one of them.
    """
    tokens = text.lower().split()
    if not tokens:
        return True

    def _looks_garbled(word: str) -> bool:
        # Short tokens are always accepted (ev, ai, bmw, suv, ...).
        if len(word) <= 3:
            return False
        # Six or more consecutive non-vowel characters.
        if re.search(r'[^aeiou]{6,}', word):
            return True
        # Long token with almost no vowels at all.
        vowel_total = len(re.findall(r'[aeiou]', word))
        if len(word) > 5 and vowel_total / len(word) < 0.1:
            return True
        # Letters fused with long digit runs (e.g. asdf12345x).
        return bool(re.search(r'[a-z]\d{3,}[a-z]|[0-9]{4,}[a-z]', word))

    garbled = sum(1 for word in tokens if _looks_garbled(word))
    # Majority vote: more than half of the tokens must look garbled.
    return garbled > len(tokens) / 2
||
|
|
|
||
|
|
# ── /suggest endpoint ─────────────────────────────────
@app.get("/suggest")
def suggest(query: str, offset: int = 0, limit: int = 15):
    """Return ranked attribute suggestions for a free-text query.

    Pipeline: reject trivial/gibberish input → normalize → embed → try the
    semantic cache (3+ word queries, first page only) → otherwise find the
    nearest domain in Postgres via the pgvector `<->` distance operator,
    bootstrapping a new domain when nothing is close enough, then rank that
    domain's attributes by embedding similarity.

    The "cache" field of the response reports which path was taken:
    "skip", "gibberish", "semantic_hit", "no_domain", or "miss".
    """
    # Queries shorter than 2 characters are not worth embedding.
    if len(query.strip()) < 2:
        return {"suggestions": [], "cache": "skip"}

    # Block gibberish server-side
    if is_gibberish(query.strip()):
        return {"suggestions": [], "cache": "gibberish"}

    normalized = normalize_query(query.strip())
    print(f"📝 Normalized: '{query}' → '{normalized}'")

    # Embedding is memoized per normalized query by the embedding cache.
    embedding = get_or_encode(normalized, model)
    word_count = len(query.strip().split())

    # Use semantic cache only for full queries (3+ words) on first page
    if word_count >= 3 and offset == 0:
        # NOTE(review): the normalized query text is passed as the cache
        # "domain" — confirm get_semantic_cache expects the query here
        # rather than an actual domain name.
        cached = get_semantic_cache(embedding, domain=normalized)
        if cached:
            print(f"✅ Semantic cache HIT")
            # offset is always 0 on this path, so this is just cached[:limit].
            return {"suggestions": cached[offset:offset + limit], "cache": "semantic_hit"}

    # pgvector accepts the list's string repr when CAST to vector in SQL.
    emb_param = str(embedding)
    with engine.begin() as conn:
        # Nearest domain by embedding distance (smaller = closer).
        domain_row = conn.execute(text("""
            SELECT id, name,
                   embedding <-> CAST(:emb AS vector) AS distance
            FROM domains
            ORDER BY distance
            LIMIT 1
        """), {"emb": emb_param}).fetchone()

        # NOTE(review): 0.8 is the max acceptable distance for a domain
        # match — presumably tuned empirically; confirm before changing.
        if domain_row is None or domain_row.distance > 0.8:
            if not is_gibberish(query):
                try:
                    # Create a new domain (and its attributes) for this query.
                    bootstrap_domain(query)
                    print(f"✅ Bootstrapped: {query}")
                except Exception as e:
                    # Best-effort: fall through and retry the lookup anyway.
                    print(f"⚠️ Bootstrap failed: {e}")

                # Re-query: the freshly bootstrapped domain should now win.
                domain_row = conn.execute(text("""
                    SELECT id, name
                    FROM domains
                    ORDER BY embedding <-> CAST(:emb AS vector)
                    LIMIT 1
                """), {"emb": emb_param}).fetchone()

                if domain_row is None:
                    return {"suggestions": [], "cache": "no_domain"}
                # ✅ continues below to fetch attributes
            else:
                return {"suggestions": [], "cache": "no_domain"}

        # Rank the matched domain's attributes by cosine-style score
        # (1 - distance), paginated by limit/offset.
        results = conn.execute(text("""
            SELECT a.name,
                   1 - (a.embedding <-> CAST(:emb AS vector)) AS score
            FROM attributes a
            JOIN dimension_groups g ON a.group_id = g.id
            WHERE g.domain_id = :domain_id
            ORDER BY score DESC
            LIMIT :limit OFFSET :offset
        """), {"emb": emb_param, "domain_id": domain_row.id,
               "limit": limit, "offset": offset})

        # Materialize rows while the connection is still open.
        suggestions = [r[0] for r in results]
        domain_name = domain_row.name

    # Deduplicate (case-insensitive, first occurrence wins, order preserved)
    seen = set()
    ranked = []
    for name in suggestions:
        if name.lower() not in seen:
            seen.add(name.lower())
            ranked.append(name)

    # Cache only full queries on first page
    if word_count >= 3 and offset == 0:
        write_cache_async(normalized, embedding, ranked, domain=domain_name)

    return {"suggestions": ranked, "cache": "miss", "domain": domain_name}
||
|
|
|
||
|
|
# ── /generate endpoint ────────────────────────────────
class GenerateRequest(BaseModel):
    """Request body for POST /generate."""

    # Free-text user question to be answered.
    query: str
    # Criteria names the user selected; joined into the LLM prompt.
    selected_attributes: list[str]
    # Prior conversation turns. generate() reads h['query'] and h['answer']
    # from each entry (required) and h.get('chips', []) (optional).
    chat_history: list[dict] = []
|
||
|
|
|
||
|
|
@app.post("/generate")
def generate(request: GenerateRequest):
    """Answer the user's question via Vertex AI Gemini, structured by the
    selected criteria.

    Builds a prompt from optional chat history, the query, and the selected
    attributes, then POSTs to the Vertex AI generateContent REST endpoint
    with up to 3 attempts: exponential backoff on HTTP 429, fixed 2s backoff
    on timeout. Any other error aborts immediately with a generic message.

    Always returns ``{"answer": str}`` — failures are reported in the answer
    text, not as HTTP error status codes.
    """
    if not request.query.strip():
        return {"answer": "Please enter a query."}

    # Flatten prior turns into plain text for prompt context.
    # NOTE(review): assumes each history entry has 'query' and 'answer'
    # keys — a missing key raises KeyError here (this runs before the
    # try/except retry loop below).
    history_text = ""
    if request.chat_history:
        history_text = "\n".join([
            f"User: {h['query']}\nCriteria: {', '.join(h.get('chips', []))}\nAnswer: {h['answer']}"
            for h in request.chat_history
        ])
        history_text = f"Previous conversation:\n{history_text}\n\n"

    attributes = ", ".join(request.selected_attributes) if request.selected_attributes else "general evaluation"

    # Prompt template — forces a direct answer, a per-criterion breakdown,
    # and a bottom-line summary, all anchored to the literal query text.
    prompt = f"""{history_text}USER QUESTION: "{request.query}"
EVALUATION CRITERIA SELECTED: {attributes}

You are an expert advisor. The user has specifically asked about "{request.query}".
Your job is to answer the user's question "{request.query}" and analyze it through each of the selected criteria.

Answer the question "{request.query}" directly first in 2-3 sentences.
Then for each criterion in [{attributes}], explain how it applies specifically to "{request.query}".

Format your response as:

## About: {request.query}
[Direct answer to the user's question in 2-3 sentences]

---
[For each selected criterion:]
**[Criterion Name]**
- How this applies to "{request.query}"
- Specific facts, numbers, or data
- Recommendation

---
## Bottom Line
[2-3 sentence summary answering: should the user go with "{request.query}"? What should they prioritize?]

STRICT RULES:
- Every sentence must be about "{request.query}" specifically
- Never give generic advice not related to "{request.query}"
- If unsure about a fact, say "verify on official website"
- Use real numbers where confident
- Total response under 400 words"""

    # Vertex AI REST endpoint for the Gemini model, built from module config.
    url = f"https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{GOOGLE_PROJECT_ID}/locations/{LOCATION}/publishers/google/models/gemini-2.5-flash-lite:generateContent"

    print(f"🔄 Calling Vertex AI for: {request.query}")

    # Up to 3 attempts; only 429s and timeouts are retried.
    for attempt in range(3):
        try:
            res = http_requests.post(
                url,
                headers={
                    # Short-lived OAuth token fetched per attempt.
                    "Authorization": f"Bearer {get_access_token()}",
                    "Content-Type": "application/json"
                },
                json={
                    "contents": [{"role": "user", "parts": [{"text": prompt}]}],
                    "generationConfig": {"temperature": 0.3, "maxOutputTokens": 800}
                },
                timeout=30
            )

            # Rate limited: back off 1s, 2s, 4s and retry.
            if res.status_code == 429:
                wait = 2 ** attempt
                print(f"⏳ Rate limited, waiting {wait}s...")
                time.sleep(wait)
                continue

            print(f"✅ Vertex response: {res.status_code}")
            res.raise_for_status()
            # First candidate's first text part is the answer.
            answer = res.json()["candidates"][0]["content"]["parts"][0]["text"]
            return {"answer": answer}

        except http_requests.exceptions.Timeout:
            print(f"⏰ Timeout on attempt {attempt + 1}")
            if attempt == 2:
                return {"answer": "Request timed out. Please try again."}
            time.sleep(2)
        except Exception as e:
            # Non-timeout failures (HTTP errors, bad JSON shape, auth) are
            # NOT retried — the loop exits immediately with this message.
            print(f"❌ Error: {e}")
            return {"answer": "Error getting response. Please try again."}

    # Reached only if all 3 attempts were consumed by 429 retries.
    return {"answer": "Could not get response after 3 attempts. Please try again."}