You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

177 lines
6.2 KiB

4 days ago
# app/bootstrap.py
from dotenv import load_dotenv
import time
import os, json, re, requests
from sqlalchemy import create_engine, text
from sentence_transformers import SentenceTransformer
import requests as http_requests
import google.auth
import google.auth.transport.requests
import os
from app.vertex_client import get_access_token
# Populate os.environ from a .env file (if present) before any settings are read.
load_dotenv()

# Database connection string. Read from the environment so credentials are not
# hard-coded in source; the previous literal is kept as the local-dev default.
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://postgres:postgres@localhost:5432/decision_engine",
)
engine = create_engine(DATABASE_URL)

# Sentence-embedding model used to embed domain names and attribute names.
model = SentenceTransformer("all-MiniLM-L6-v2")

# Google Cloud project and region used to build the Vertex AI Gemini endpoint URL.
PROJECT_ID = os.getenv("GOOGLE_PROJECT_ID")
LOCATION = os.getenv("GOOGLE_LOCATION", "us-central1")

# The only category keys accepted from Gemini; validate_schema drops any other key.
VALID_CATEGORIES = {
    "Performance", "Financial", "Risk", "Maintenance", "Benefits",
    "Time", "Requirements", "Scalability", "Alternatives", "Usability",
    "Security", "Reliability", "Support", "Sustainability",
}
def clean_json_response(raw: str) -> str:
    """Strip markdown fences and whitespace, and return the first JSON object in *raw*.

    The model is asked for bare JSON but may wrap it in ```json fences or
    surround it with prose. The first ``{`` that begins a complete, parseable
    JSON object wins. Trailing text is therefore not swept into the result,
    even when it contains braces or another object.

    Args:
        raw: Text returned by the model.

    Returns:
        The substring holding the first parseable JSON object. If no ``{``
        starts a parseable object, returns the span from the first ``{`` to
        the last ``}`` (the previous behaviour). The caller's ``json.loads``
        then reports the syntax error.

    Raises:
        ValueError: If the text contains no ``{...}`` span at all.
    """
    # Remove markdown code fences (```json ... ``` or bare ```).
    text = re.sub(r"```json\s*", "", raw)
    text = re.sub(r"```\s*", "", text)
    text = text.strip()

    # Try each '{' in turn until one begins a complete JSON document.
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            _, length = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
        else:
            return text[start:start + length]

    # Nothing parsed: fall back to the widest brace span so the caller's
    # json.loads surfaces a descriptive decode error (and triggers a retry).
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if match:
        return match.group(0)
    raise ValueError("No JSON object found in response")
def validate_schema(
    data: dict,
    valid_categories: "set[str] | None" = None,
    max_per_category: int = 5,
) -> dict:
    """Validate and clean a parsed Gemini response.

    Processing rules:

    - Normalize category names with ``str(name).strip().title()``.
    - Keep only categories in the allowed set. Keys that normalize to the same
      category are merged.
    - Treat a bare string value as a one-item attribute list. Skip any other
      non-list value (e.g. a number) instead of crashing.
    - Keep only non-blank string attributes, stripped and de-duplicated in
      first-seen order.
    - Cap each category at ``max_per_category`` attributes.
    - Drop categories that end up empty.

    Args:
        data: Parsed JSON object mapping category name to a list of attributes.
        valid_categories: Allowed category names; defaults to the module-level
            ``VALID_CATEGORIES``.
        max_per_category: Maximum attributes kept per category (default 5).

    Returns:
        Dict mapping normalized category name to a non-empty list of attribute names.

    Raises:
        ValueError: If *data* is not a dict, or no valid category remains.
    """
    allowed = VALID_CATEGORIES if valid_categories is None else valid_categories
    if not isinstance(data, dict):
        raise ValueError(f"Expected a JSON object, got {type(data).__name__}")

    # category -> ordered set of attribute names (dict keys keep insertion order)
    merged = {}
    for category, attributes in data.items():
        cat = str(category).strip().title()
        if cat not in allowed:
            print(f"Skipping unknown category: {cat}")
            continue
        # A bare string must not be iterated character by character.
        if isinstance(attributes, str):
            attributes = [attributes]
        elif not isinstance(attributes, (list, tuple)):
            print(f"Skipping non-list attributes for category: {cat}")
            continue
        bucket = merged.setdefault(cat, {})
        for attr in attributes:
            if isinstance(attr, str) and attr.strip():
                bucket.setdefault(attr.strip(), None)

    cleaned = {}
    for cat, attrs in merged.items():
        capped = list(attrs)[:max_per_category]
        if capped:
            cleaned[cat] = capped

    if not cleaned:
        raise ValueError("No valid categories found after validation")
    return cleaned
def _build_gemini_prompt(query: str) -> str:
    """Return the criteria-generation prompt sent to Gemini for *query*."""
    # The text is sent verbatim; doubled braces are literal braces in the f-string.
    return f"""You are a world-class decision analysis expert.
Task: Generate the MOST IMPORTANT and FREQUENTLY USED evaluation criteria for someone making a decision about: "{query}"
Rules:
- Only include criteria that are HIGHLY RELEVANT to "{query}"
- Prioritize criteria that people MOST COMMONLY consider for this topic
- Each attribute must be SPECIFIC and MEASURABLE, not generic
- Order attributes by importance (most important first)
- Use concise names (2-5 words max per attribute)
Output format: Pure JSON only. No markdown. No explanation.
Use ONLY these category keys: Performance, Financial, Risk, Maintenance, Benefits, Time, Requirements, Scalability, Alternatives, Usability, Security, Reliability, Support, Sustainability
Example for "buying a car":
{{"Performance":["Engine Power","Top Speed","0-100 Acceleration","Fuel Efficiency"],"Financial":["Purchase Price","Insurance Cost","Resale Value","Running Cost"],"Maintenance":["Service Interval","Spare Parts Availability","Warranty Period"]}}
Now generate for: "{query}"
Return ONLY the JSON object."""


def _extract_gemini_text(payload: dict) -> str:
    """Return the text of the first part of the first candidate in a generateContent reply.

    Raises:
        ValueError: If the payload lacks that path (e.g. no ``candidates``).
            call_gemini then retries it like any other malformed response
            instead of crashing with KeyError/IndexError.
    """
    try:
        return payload["candidates"][0]["content"]["parts"][0]["text"]
    except (KeyError, IndexError, TypeError) as e:
        raise ValueError(f"Unexpected Gemini response shape: {e!r}") from e


def call_gemini(query: str, *, timeout: float = 60.0, max_attempts: int = 3) -> dict:
    """Ask Gemini for evaluation criteria for *query* and return them validated.

    Args:
        query: The decision topic; it is interpolated into the prompt.
        timeout: Per-request HTTP timeout in seconds, so a stalled connection
            cannot hang the caller forever.
        max_attempts: Total number of requests to make before giving up.

    Returns:
        The output of ``validate_schema``: category name -> list of attribute names.

    Raises:
        RuntimeError: If every attempt was rate limited (HTTP 429), or if the
            last attempt returned an unparseable or invalid response.
        requests.HTTPError: For other HTTP error statuses (from
            ``raise_for_status``); these are not retried.
        requests.Timeout / requests.ConnectionError: Network failures; these
            are not retried.
    """
    prompt = _build_gemini_prompt(query)
    url = (
        f"https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}"
        f"/locations/{LOCATION}/publishers/google/models/gemini-2.5-flash-lite:generateContent"
    )
    body = {
        "contents": [{"role": "user", "parts": [{"text": prompt}]}],
        "generationConfig": {"temperature": 0.2, "maxOutputTokens": 1024},
    }

    for attempt in range(max_attempts):
        try:
            res = http_requests.post(
                url,
                headers={
                    # Fetched per attempt so a refreshed token is picked up on retry.
                    "Authorization": f"Bearer {get_access_token()}",
                    "Content-Type": "application/json",
                },
                json=body,
                timeout=timeout,
            )
            # Rate limited: exponential backoff (1s, 2s, 4s, ...), but don't
            # sleep after the final attempt since no retry follows.
            if res.status_code == 429:
                if attempt < max_attempts - 1:
                    wait = 2 ** attempt
                    print(f" Rate limited, waiting {wait}s... (attempt {attempt + 1})")
                    time.sleep(wait)
                else:
                    print(f" Rate limited (attempt {attempt + 1}), no retries left")
                continue
            res.raise_for_status()
            raw = _extract_gemini_text(res.json())
            return validate_schema(json.loads(clean_json_response(raw)))
        except (json.JSONDecodeError, ValueError) as e:
            print(f" Attempt {attempt + 1} failed: {e}")
            if attempt == max_attempts - 1:
                raise RuntimeError(f"Gemini failed after {max_attempts} attempts: {e}") from e
    raise RuntimeError(f"Gemini failed after {max_attempts} attempts: rate limit")
def bootstrap_domain(query: str):
    """Generate evaluation criteria for *query* with Gemini and store them.

    Upserts one row in ``domains`` (named *query*), one row in
    ``dimension_groups`` per category, and one row in ``attributes`` per
    attribute. Each domain and attribute row gets an embedding from the
    module-level ``model``. All writes happen in a single transaction
    (``engine.begin()``).

    Args:
        query: The decision topic; also used as the domain name.

    Raises:
        Whatever ``call_gemini`` raises when it cannot produce a validated
        response; it already retries internally.
    """
    # call_gemini does its own retries and converts parse/validation failures
    # into RuntimeError, so JSONDecodeError/ValueError never reach this point.
    # The former outer retry loop here could therefore never retry.
    data = call_gemini(query)
    print("Gemini response validated")

    # Compute all embeddings before opening the transaction, so the DB
    # connection is not held during model inference. Attribute names are
    # encoded in one batch instead of one call per attribute.
    domain_embedding = model.encode(query).tolist()
    attr_texts = list(dict.fromkeys(attr for attrs in data.values() for attr in attrs))
    attr_embeddings = (
        dict(zip(attr_texts, model.encode(attr_texts).tolist())) if attr_texts else {}
    )

    with engine.begin() as conn:
        # The no-op DO UPDATE makes RETURNING yield the id of an existing row too.
        domain_id = conn.execute(text("""
            INSERT INTO domains (name, embedding)
            VALUES (:n, CAST(:e AS vector))
            ON CONFLICT (name) DO UPDATE SET name = EXCLUDED.name
            RETURNING id
        """), {"n": query, "e": str(domain_embedding)}).scalar()

        for group, attrs in data.items():
            group_id = conn.execute(text("""
                INSERT INTO dimension_groups (domain_id, name)
                VALUES (:d, :g)
                ON CONFLICT (domain_id, name) DO UPDATE SET name = EXCLUDED.name
                RETURNING id
            """), {"d": domain_id, "g": group}).scalar()

            for attr in attrs:
                conn.execute(text("""
                    INSERT INTO attributes (group_id, name, embedding)
                    VALUES (:gid, :name, CAST(:emb AS vector))
                    ON CONFLICT (group_id, name) DO NOTHING
                """), {"gid": group_id, "name": attr, "emb": str(attr_embeddings[attr])})

    print(f"Domain bootstrapped: {query}")