|
|
|
|
# app/bootstrap.py
|
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
|
import time
|
|
|
|
|
import os, json, re, requests
|
|
|
|
|
from sqlalchemy import create_engine, text
|
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
import requests as http_requests
|
|
|
|
|
import google.auth
|
|
|
|
|
import google.auth.transport.requests
|
|
|
|
|
from app.vertex_client import get_access_token
|
|
|
|
|
|
|
|
|
|
# Load environment variables from a local .env file, if present.
load_dotenv()

# Postgres connection string. NOTE(review): assumed to always be set —
# create_engine(None) would raise at import time; confirm deployment env.
DATABASE_URL = os.getenv("DATABASE_URL")

engine = create_engine(DATABASE_URL)

# Shared sentence-embedding model, reused for domain/attribute embeddings.
model = SentenceTransformer("all-MiniLM-L6-v2")

# Google Cloud project and region used to build the Vertex AI endpoint URL.
PROJECT_ID = os.getenv("GOOGLE_PROJECT_ID")
LOCATION = "us-central1"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Closed set of category names Gemini is allowed to return; validate_schema
# drops any category outside this set.
VALID_CATEGORIES = {
    "Performance", "Financial", "Risk", "Maintenance", "Benefits",
    "Time", "Requirements", "Scalability", "Alternatives", "Usability",
    "Security", "Reliability", "Support", "Sustainability"
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_json_response(raw: str) -> str:
    """Strip markdown fences/whitespace from *raw* and return the first JSON object.

    Raises:
        ValueError: If no ``{...}`` object can be found in the text.
    """
    # Drop any markdown code fences (```json ... ```), then surrounding whitespace.
    without_fences = re.sub(r"```\s*", "", re.sub(r"```json\s*", "", raw)).strip()

    # Greedily grab the first brace-delimited object, spanning newlines.
    found = re.search(r"\{.*\}", without_fences, re.DOTALL)
    if found is None:
        raise ValueError("No JSON object found in response")
    return found.group(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_schema(data: dict, query: str) -> dict:
    """Validate and clean a Gemini ``{category: [attributes]}`` mapping.

    Keeps only categories present in VALID_CATEGORIES, drops empty or
    non-string attributes and attributes containing terms irrelevant to the
    query's topic, and caps each category at 12 attributes.

    Args:
        data: Parsed JSON from Gemini, expected ``{category: [attr, ...]}``.
        query: The user's original decision query (used for topic filtering).

    Returns:
        Cleaned ``{category: [attributes]}`` dict.

    Raises:
        ValueError: If no valid categories remain after validation.
    """
    cleaned = {}
    query_lower = query.lower()

    # Topic-aware relevance: attribute terms that are clearly off-topic
    # when the query mentions the given topic.
    IRRELEVANT_COMBOS = {
        "shoes": ["processor", "ram", "gpu", "cpu", "battery", "charging"],
        "food": ["processor", "gpu", "engine", "torque", "horsepower"],
        "college": ["torque", "engine", "gpu", "charging speed"],
    }

    # Collect blocked terms for every topic mentioned in the query.
    blocked = []
    for topic, terms in IRRELEVANT_COMBOS.items():
        if topic in query_lower:
            blocked.extend(terms)

    for category, attributes in data.items():
        cat = category.strip().title()
        if cat not in VALID_CATEGORIES:
            print(f"Skipping unknown category: {cat}")
            continue

        # Guard: Gemini occasionally returns a bare string instead of a list.
        # Iterating a string would yield single characters and silently store
        # one-letter "attributes", so reject the category's payload instead.
        if not isinstance(attributes, (list, tuple)):
            print(f"Skipping malformed attributes for category: {cat}")
            continue

        attrs = []
        for a in attributes:
            if not isinstance(a, str) or not a.strip():
                continue
            # ✅ Reject attributes containing blocked terms
            a_lower = a.lower()
            if any(b in a_lower for b in blocked):
                print(f"❌ Rejected irrelevant attribute: {a}")
                continue
            attrs.append(a.strip())

        attrs = attrs[:12]  # hard cap per category
        if attrs:
            cleaned[cat] = attrs

    if not cleaned:
        raise ValueError("No valid categories found after validation")

    return cleaned
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def call_gemini(query: str) -> dict:
    """Ask Gemini (Vertex AI) for decision-evaluation criteria for *query*.

    Posts the prompt to the gemini-2.5-flash-lite model, backing off on
    HTTP 429 and retrying (up to 3 attempts) when the response body is
    malformed or missing the expected candidates structure.

    Returns:
        Validated ``{category: [attributes]}`` dict (see validate_schema).

    Raises:
        RuntimeError: On a non-retryable API error, or after 3 failed attempts.
    """
    prompt = f"""You are a world-class decision analysis expert.

Task: Generate a COMPREHENSIVE list of evaluation criteria for: "{query}"

MANDATORY RULES:
- Select 6-8 categories from the allowed list
- Generate EXACTLY 8-10 attributes per category — this is mandatory
- Each attribute must be SPECIFIC and MEASURABLE for "{query}"
- First 3 attributes in EVERY category must be the MOST SEARCHED specs
- For tech products: always start with Processor, RAM, Battery, Display, Camera
- For vehicles: always start with Engine, Mileage, Price, Safety
- Order strictly by: most searched → most compared → most reviewed
- Use concise names (2-5 words max)
- DO NOT generate less than 8 attributes per category

Allowed category keys:
Performance, Financial, Risk, Maintenance, Benefits, Time, Requirements,
Scalability, Alternatives, Usability, Security, Reliability, Support, Sustainability

Example of CORRECT format with enough attributes:
{{"Performance":["Engine Power","Torque Output","Top Speed","0-100 kmph Time","Fuel Efficiency","Gear Smoothness","Braking Distance","Tyre Grip","Suspension Quality","NVH Levels"],"Financial":["Ex-showroom Price","On-road Price","EMI Options","Insurance Cost","Fuel Cost Monthly","Resale Value","Maintenance Cost","Road Tax","Accessories Cost","Total Ownership Cost"],"Reliability":["Engine Reliability","Electrical Issues","Common Problems","Long Term Durability","Brand Track Record","Owner Satisfaction","Recall History","Service Quality","Spare Parts Life","Warranty Claims"]}}

Now generate for: "{query}"
Return ONLY valid JSON. No markdown. No explanation. Minimum 8 attributes per category."""

    url = f"https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/gemini-2.5-flash-lite:generateContent"

    for attempt in range(3):
        try:
            res = http_requests.post(
                url,
                headers={
                    "Authorization": f"Bearer {get_access_token()}",
                    "Content-Type": "application/json"
                },
                json={
                    "contents": [{"role": "user", "parts": [{"text": prompt}]}],
                    "generationConfig": {
                        "temperature": 0.1,  # lower = more accurate, less random
                        "maxOutputTokens": 1024
                    }
                },
                timeout=30
            )

            if res.status_code == 429:
                wait = 2 ** attempt  # exponential backoff: 1s, 2s, 4s
                print(f"Rate limited, waiting {wait}s... (attempt {attempt + 1})")
                time.sleep(wait)
                continue

            print("Status Code:", res.status_code)
            if res.status_code != 200:
                print("❌ ERROR RESPONSE:")
                print(res.text)
                raise RuntimeError("Vertex API failed")

            # Parse the body once and reuse it (previously res.json() ran twice).
            data_json = res.json()
            print("✅ RAW RESPONSE:", str(data_json)[:500])

            raw = data_json["candidates"][0]["content"]["parts"][0]["text"]
            clean = clean_json_response(raw)
            data = json.loads(clean)
            validated = validate_schema(data, query)
            print(f"Gemini generated {sum(len(v) for v in validated.values())} attributes across {len(validated)} categories")
            return validated

        # KeyError/IndexError cover responses missing the expected
        # candidates/content/parts structure (e.g. safety-blocked output) —
        # previously these escaped the retry loop entirely.
        except (json.JSONDecodeError, ValueError, KeyError, IndexError) as e:
            print(f"Attempt {attempt + 1} failed: {e}")
            if attempt == 2:
                raise RuntimeError(f"Gemini failed after 3 attempts: {e}")

    raise RuntimeError("Gemini failed after 3 attempts: rate limit")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def bootstrap_domain(query: str):
    """Generate, validate, and persist evaluation criteria for *query*.

    Calls Gemini (with retries), applies a minimum-quality gate, skips
    near-duplicate domains (pgvector L2 distance < 0.15), then stores the
    domain, its dimension groups, and attribute embeddings in one
    transaction.

    Raises:
        RuntimeError: If Gemini fails after all retries or returns no data.
        ValueError: If fewer than 10 attributes survive validation.
    """
    data = None
    for attempt in range(3):
        try:
            data = call_gemini(query)
            print(f"Gemini response validated on attempt {attempt + 1}")
            break
        except (json.JSONDecodeError, ValueError) as e:
            print(f"Attempt {attempt + 1} failed: {e}")
            if attempt == 2:
                raise RuntimeError(f"Gemini failed after 3 attempts: {e}")

    if not data:
        raise RuntimeError("No data generated")

    # ✅ Quality gate — reject low quality bootstraps
    total_attrs = sum(len(v) for v in data.values())
    if total_attrs < 10:
        raise ValueError(f"Quality gate failed: only {total_attrs} attributes generated")

    # ✅ Duplicate domain detection — check similarity before storing.
    # Reuse the module-level SentenceTransformer: reloading
    # "all-MiniLM-L6-v2" on every call (as before) is slow and wastes memory.
    domain_embedding = model.encode(query).tolist()

    with engine.begin() as conn:
        # Check if a very similar domain already exists
        existing = conn.execute(text("""
            SELECT name, embedding <-> CAST(:emb AS vector) AS distance
            FROM domains
            ORDER BY distance
            LIMIT 1
        """), {"emb": str(domain_embedding)}).fetchone()

        if existing and existing.distance < 0.15:
            print(f"⚠️ Similar domain already exists: '{existing.name}' (distance: {existing.distance:.3f}) — skipping bootstrap")
            return  # ✅ Don't store duplicate

        # Store domain
        domain_id = conn.execute(text("""
            INSERT INTO domains (name, embedding)
            VALUES (:n, CAST(:e AS vector))
            ON CONFLICT (name) DO UPDATE SET name = EXCLUDED.name
            RETURNING id
        """), {"n": query, "e": str(domain_embedding)}).scalar()

        for group, attrs in data.items():
            group_id = conn.execute(text("""
                INSERT INTO dimension_groups (domain_id, name)
                VALUES (:d, :g)
                ON CONFLICT (domain_id, name) DO UPDATE SET name = EXCLUDED.name
                RETURNING id
            """), {"d": domain_id, "g": group}).scalar()

            for attr in attrs:
                emb = model.encode(attr).tolist()
                conn.execute(text("""
                    INSERT INTO attributes (group_id, name, embedding)
                    VALUES (:gid, :name, CAST(:emb AS vector))
                    ON CONFLICT (group_id, name) DO NOTHING
                """), {"gid": group_id, "name": attr, "emb": str(emb)})

    print(f"✅ Domain bootstrapped: {query} ({total_attrs} attributes)")
|