You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

225 lines
8.7 KiB

# app/bootstrap.py
from dotenv import load_dotenv
import time
import os, json, re, requests
from sqlalchemy import create_engine, text
from sentence_transformers import SentenceTransformer
import requests as http_requests
import google.auth
import google.auth.transport.requests
from app.vertex_client import get_access_token
# Populate os.environ from a local .env file (if any) before the lookups below.
load_dotenv()

# Database connection string; passed straight to SQLAlchemy.
DATABASE_URL = os.environ.get("DATABASE_URL")
# Shared SQLAlchemy engine used for all database access in this module.
engine = create_engine(DATABASE_URL)
# Sentence-embedding model, constructed once when the module is imported.
model = SentenceTransformer("all-MiniLM-L6-v2")
# Google Cloud project and Vertex AI region, interpolated into the Gemini URL.
PROJECT_ID = os.environ.get("GOOGLE_PROJECT_ID")
LOCATION = "us-central1"

# Category keys accepted from Gemini output; any other key is skipped by
# validate_schema.
VALID_CATEGORIES = {
    "Alternatives",
    "Benefits",
    "Financial",
    "Maintenance",
    "Performance",
    "Reliability",
    "Requirements",
    "Risk",
    "Scalability",
    "Security",
    "Support",
    "Sustainability",
    "Time",
    "Usability",
}
def clean_json_response(raw: str) -> str:
    """Strip markdown fences and whitespace, then return the first JSON object.

    Args:
        raw: Model output that may wrap the JSON in ```json fences and/or
            surround it with explanatory prose.

    Returns:
        The exact substring containing the first complete, parseable JSON
        object (text before it and anything after its closing brace is
        dropped).

    Raises:
        ValueError: If no parseable JSON object is present.
    """
    # Remove markdown code fences the model sometimes emits despite instructions.
    raw = re.sub(r"```json\s*", "", raw)
    raw = re.sub(r"```\s*", "", raw)
    raw = raw.strip()

    decoder = json.JSONDecoder()
    # Try each '{' in turn. raw_decode stops at the end of the first complete
    # object, so trailing prose (even prose containing braces) is ignored --
    # unlike a greedy r"\{.*\}" match, which runs to the *last* '}'.
    start = raw.find("{")
    while start != -1:
        try:
            _, end = decoder.raw_decode(raw, start)
        except json.JSONDecodeError:
            start = raw.find("{", start + 1)
        else:
            return raw[start:end]
    raise ValueError("No JSON object found in response")
def validate_schema(
    data: dict,
    query: str,
    *,
    valid_categories: "set[str] | None" = None,
    max_attrs: int = 12,
) -> dict:
    """Normalize and filter the category -> attributes mapping from Gemini.

    Category keys are stripped and title-cased; keys not in the allowed set
    are skipped. Each category's value must be a list (or tuple) of strings.
    Blank or non-string entries are dropped, the rest are stripped and
    filtered against per-topic blocklists (e.g. no "processor" attributes when
    the query mentions "shoes"), and each category is capped at ``max_attrs``
    entries. Categories left empty are omitted.

    Args:
        data: Parsed JSON object mapping category name to attribute names.
        query: The user query; decides which topic blocklists apply.
        valid_categories: Allowed (title-cased) category names. Defaults to
            the module-level ``VALID_CATEGORIES``.
        max_attrs: Maximum number of attributes kept per category.

    Returns:
        A new dict mapping normalized category name to its cleaned list.

    Raises:
        ValueError: If ``data`` is not a dict, or no category survives.
    """
    if valid_categories is None:
        valid_categories = VALID_CATEGORIES
    if not isinstance(data, dict):
        # ValueError (not AttributeError) so callers' retry logic handles it.
        raise ValueError(f"Expected a JSON object, got {type(data).__name__}")

    # Topic keyword (substring of the query) -> attribute terms that make no
    # sense for that topic.
    irrelevant_combos = {
        "shoes": ["processor", "ram", "gpu", "cpu", "battery", "charging"],
        "food": ["processor", "gpu", "engine", "torque", "horsepower"],
        "college": ["torque", "engine", "gpu", "charging speed"],
    }
    query_lower = query.lower()
    blocked = []
    for topic, terms in irrelevant_combos.items():
        if topic in query_lower:
            blocked.extend(terms)

    # Match blocked terms as whole words (optionally with a plural "s") so that
    # short terms like "ram" or "engine" don't reject "Rewards Program",
    # "Frame Material" or "Engineering Programs". With no blocked terms the
    # pattern would match everywhere, so it is only built when needed.
    blocked_re = (
        re.compile(r"\b(?:" + "|".join(map(re.escape, blocked)) + r")s?\b")
        if blocked
        else None
    )

    cleaned = {}
    for category, attributes in data.items():
        cat = category.strip().title()
        if cat not in valid_categories:
            print(f"Skipping unknown category: {cat}")
            continue
        # A bare string would otherwise be iterated character by character.
        if not isinstance(attributes, (list, tuple)):
            print(f"Skipping category with non-list attributes: {cat}")
            continue
        attrs = []
        for a in attributes:
            if not isinstance(a, str) or not a.strip():
                continue
            # ✅ Reject attributes containing blocked terms
            if blocked_re is not None and blocked_re.search(a.lower()):
                print(f"❌ Rejected irrelevant attribute: {a}")
                continue
            attrs.append(a.strip())
        attrs = attrs[:max_attrs]
        if attrs:
            cleaned[cat] = attrs

    if not cleaned:
        raise ValueError("No valid categories found after validation")
    return cleaned
def call_gemini(query: str) -> dict:
    """Ask Gemini (Vertex AI ``generateContent``) for evaluation criteria.

    Builds a prompt asking for a JSON object of category -> attribute names
    for ``query``, posts it to the gemini-2.5-flash-lite endpoint, and parses
    and validates the reply with ``clean_json_response`` and
    ``validate_schema``.

    Up to 3 attempts are made:
    - HTTP 429 and transient 5xx (500/502/503/504) responses are retried with
      exponential backoff (1s, 2s).
    - Network timeouts/connection errors are retried the same way. On the
      last attempt the original ``requests`` exception propagates.
    - Replies that are malformed (missing candidates/parts/text, bad JSON)
      or that fail validation are retried immediately.

    Args:
        query: The decision topic; interpolated verbatim into the prompt.

    Returns:
        The validated mapping of category name to attribute names.

    Raises:
        RuntimeError: On a non-retryable HTTP error, or when every attempt
            fails.
    """
    prompt = f"""You are a world-class decision analysis expert.
Task: Generate a COMPREHENSIVE list of evaluation criteria for: "{query}"
MANDATORY RULES:
- Select 6-8 categories from the allowed list
- Generate EXACTLY 8-10 attributes per category — this is mandatory
- Each attribute must be SPECIFIC and MEASURABLE for "{query}"
- First 3 attributes in EVERY category must be the MOST SEARCHED specs
- For tech products: always start with Processor, RAM, Battery, Display, Camera
- For vehicles: always start with Engine, Mileage, Price, Safety
- Order strictly by: most searched → most compared → most reviewed
- Use concise names (2-5 words max)
- DO NOT generate less than 8 attributes per category
Allowed category keys:
Performance, Financial, Risk, Maintenance, Benefits, Time, Requirements,
Scalability, Alternatives, Usability, Security, Reliability, Support, Sustainability
Example of CORRECT format with enough attributes:
{{"Performance":["Engine Power","Torque Output","Top Speed","0-100 kmph Time","Fuel Efficiency","Gear Smoothness","Braking Distance","Tyre Grip","Suspension Quality","NVH Levels"],"Financial":["Ex-showroom Price","On-road Price","EMI Options","Insurance Cost","Fuel Cost Monthly","Resale Value","Maintenance Cost","Road Tax","Accessories Cost","Total Ownership Cost"],"Reliability":["Engine Reliability","Electrical Issues","Common Problems","Long Term Durability","Brand Track Record","Owner Satisfaction","Recall History","Service Quality","Spare Parts Life","Warranty Claims"]}}
Now generate for: "{query}"
Return ONLY valid JSON. No markdown. No explanation. Minimum 8 attributes per category."""
    url = f"https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/gemini-2.5-flash-lite:generateContent"

    # Server-side failures that are usually transient (429 handled separately).
    retryable_statuses = {500, 502, 503, 504}
    max_attempts = 3
    for attempt in range(max_attempts):
        is_last = attempt == max_attempts - 1
        wait = 2 ** attempt
        try:
            res = http_requests.post(
                url,
                headers={
                    "Authorization": f"Bearer {get_access_token()}",
                    "Content-Type": "application/json",
                },
                json={
                    "contents": [{"role": "user", "parts": [{"text": prompt}]}],
                    "generationConfig": {
                        "temperature": 0.1,  # lower = more accurate, less random
                        "maxOutputTokens": 1024,
                    },
                },
                timeout=30,
            )
        except (http_requests.exceptions.Timeout,
                http_requests.exceptions.ConnectionError) as e:
            # Retry network hiccups; on the final attempt re-raise unchanged so
            # callers still see the original requests exception type.
            if is_last:
                raise
            print(f"Network error, waiting {wait}s... (attempt {attempt + 1}): {e}")
            time.sleep(wait)
            continue

        if res.status_code == 429:
            if is_last:
                break  # no attempts remain, so don't sleep before giving up
            print(f"Rate limited, waiting {wait}s... (attempt {attempt + 1})")
            time.sleep(wait)
            continue

        print("Status Code:", res.status_code)
        if res.status_code != 200:
            print("❌ ERROR RESPONSE:")
            print(res.text)
            if res.status_code in retryable_statuses and not is_last:
                time.sleep(wait)
                continue
            raise RuntimeError("Vertex API failed")

        try:
            data_json = res.json()
            print("✅ RAW RESPONSE:", str(data_json)[:500])
            # A blocked or truncated reply can lack candidates/parts/text;
            # treat that like malformed output (retry) instead of crashing.
            raw = data_json["candidates"][0]["content"]["parts"][0]["text"]
            data = json.loads(clean_json_response(raw))
            validated = validate_schema(data, query)
        except (KeyError, IndexError, TypeError, ValueError) as e:
            # json.JSONDecodeError is a ValueError subclass, so it lands here.
            print(f"Attempt {attempt + 1} failed: {e}")
            if is_last:
                raise RuntimeError(f"Gemini failed after 3 attempts: {e}") from e
            continue

        print(f"Gemini generated {sum(len(v) for v in validated.values())} attributes across {len(validated)} categories")
        return validated

    # Only reachable when the final attempt was rate limited.
    raise RuntimeError("Gemini failed after 3 attempts: rate limit")
def _near_duplicate_domain(conn, embedding_literal: str):
    """Return the closest stored domain row if it counts as a duplicate, else None.

    Queries ``domains`` ordered by pgvector's ``<->`` distance to
    ``embedding_literal`` (a ``str(list_of_floats)`` vector literal). The
    nearest row is a duplicate when its distance is below 0.15; in that case a
    warning is printed and the row (with ``name`` and ``distance``) is returned.
    """
    row = conn.execute(text("""
        SELECT name, embedding <-> CAST(:emb AS vector) AS distance
        FROM domains
        ORDER BY distance
        LIMIT 1
    """), {"emb": embedding_literal}).fetchone()
    # distance is NULL when the stored embedding is NULL; never a duplicate.
    if row is not None and row.distance is not None and row.distance < 0.15:
        print(f"⚠️ Similar domain already exists: '{row.name}' (distance: {row.distance:.3f}) — skipping bootstrap")
        return row
    return None


def bootstrap_domain(query: str):
    """Generate evaluation criteria for ``query`` with Gemini and store them.

    Steps:
    1. Embed the query and skip it if a near-duplicate domain is stored.
    2. Fetch validated criteria via ``call_gemini``, which retries internally.
    3. Require at least 10 attributes in total.
    4. Embed every attribute.
    5. Upsert the domain, its dimension groups and their attributes in a
       single transaction.

    Args:
        query: The domain / user query; stored verbatim as ``domains.name``.

    Returns:
        None. Returns early without storing anything when a near-duplicate
        domain exists.

    Raises:
        RuntimeError: If Gemini generation fails or yields no data.
        ValueError: If fewer than 10 attributes survive validation.
    """
    # Reuse the module-level model instead of reloading its weights per call.
    domain_embedding = str(model.encode(query).tolist())

    # ✅ Duplicate domain detection — done before the Gemini call so no API
    # request (or attribute embedding) is spent on an already-known domain.
    with engine.connect() as conn:
        if _near_duplicate_domain(conn, domain_embedding) is not None:
            return

    # call_gemini already retries and raises RuntimeError once it gives up,
    # so an outer retry loop here would never run more than once.
    data = call_gemini(query)
    if not data:
        raise RuntimeError("No data generated")
    print("Gemini response validated")

    # ✅ Quality gate — reject low quality bootstraps
    total_attrs = sum(len(v) for v in data.values())
    if total_attrs < 10:
        raise ValueError(f"Quality gate failed: only {total_attrs} attributes generated")

    # Embed all attributes (one batch per group) before opening the write
    # transaction, so model inference doesn't keep the transaction open.
    embedded = {}
    for group, attrs in data.items():
        vectors = model.encode(list(attrs)).tolist()
        embedded[group] = [
            {"name": attr, "emb": str(vec)} for attr, vec in zip(attrs, vectors)
        ]

    with engine.begin() as conn:
        # Re-check inside the write transaction: a similar domain may have been
        # stored while Gemini was running. This narrows, but does not close,
        # the race between concurrent bootstraps.
        if _near_duplicate_domain(conn, domain_embedding) is not None:
            return

        # The no-op DO UPDATE makes RETURNING yield the id even when the row
        # already exists (DO NOTHING would return no row).
        domain_id = conn.execute(text("""
            INSERT INTO domains (name, embedding)
            VALUES (:n, CAST(:e AS vector))
            ON CONFLICT (name) DO UPDATE SET name = EXCLUDED.name
            RETURNING id
        """), {"n": query, "e": domain_embedding}).scalar()

        for group, rows in embedded.items():
            group_id = conn.execute(text("""
                INSERT INTO dimension_groups (domain_id, name)
                VALUES (:d, :g)
                ON CONFLICT (domain_id, name) DO UPDATE SET name = EXCLUDED.name
                RETURNING id
            """), {"d": domain_id, "g": group}).scalar()
            if rows:
                # A list of parameter dicts runs as one executemany call.
                conn.execute(text("""
                    INSERT INTO attributes (group_id, name, embedding)
                    VALUES (:gid, :name, CAST(:emb AS vector))
                    ON CONFLICT (group_id, name) DO NOTHING
                """), [{"gid": group_id, **row} for row in rows])

    print(f"✅ Domain bootstrapped: {query} ({total_attrs} attributes)")