You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
177 lines
6.2 KiB
177 lines
6.2 KiB
|
4 days ago
|
# app/bootstrap.py
|
||
|
|
from dotenv import load_dotenv
|
||
|
|
import time
|
||
|
|
import os, json, re, requests
|
||
|
|
from sqlalchemy import create_engine, text
|
||
|
|
from sentence_transformers import SentenceTransformer
|
||
|
|
import requests as http_requests
|
||
|
|
import google.auth
|
||
|
|
import google.auth.transport.requests
|
||
|
|
import os
|
||
|
|
from app.vertex_client import get_access_token
|
||
|
|
|
||
|
|
# Load environment configuration before any environment lookup below.
load_dotenv()

# --- Database and embedding model (created once at import time) ---
DATABASE_URL = "postgresql://postgres:postgres@localhost:5432/decision_engine"
engine = create_engine(DATABASE_URL)
model = SentenceTransformer("all-MiniLM-L6-v2")

# --- Vertex AI endpoint settings ---
PROJECT_ID = os.environ.get("GOOGLE_PROJECT_ID")
LOCATION = "us-central1"

# The only category keys accepted from the model's JSON output; anything
# else is dropped by validate_schema().
VALID_CATEGORIES = {
    "Performance",
    "Financial",
    "Risk",
    "Maintenance",
    "Benefits",
    "Time",
    "Requirements",
    "Scalability",
    "Alternatives",
    "Usability",
    "Security",
    "Reliability",
    "Support",
    "Sustainability",
}
|
||
|
|
|
||
|
|
|
||
|
|
def clean_json_response(raw: str) -> str:
    """Strip markdown fences and return the text of the first JSON object.

    The model sometimes wraps its answer in ```json fences or adds prose
    around it. The object is located by decoding from the first ``{`` so
    that trailing text (even text containing ``}``) is not swallowed into
    the result.

    Args:
        raw: Raw text returned by the model.

    Returns:
        The substring holding the first JSON object. If nothing decodable
        starts at the first ``{``, the span from the first ``{`` to the last
        ``}`` is returned unchanged so the caller's ``json.loads`` reports
        the actual parse error.

    Raises:
        ValueError: If the text contains no ``{...}`` span at all.
    """
    # Remove markdown code fences (```json first, then any bare ```).
    raw = re.sub(r"```json\s*", "", raw)
    raw = re.sub(r"```\s*", "", raw)
    raw = raw.strip()

    # Greedy span: first "{" through last "}". Used to detect presence and
    # as the fallback result.
    match = re.search(r"\{.*\}", raw, re.DOTALL)
    if match is None:
        raise ValueError("No JSON object found in response")

    start = match.start()
    try:
        # raw_decode stops at the end of the first complete JSON value, so
        # anything the model appended after the object is ignored.
        _, end = json.JSONDecoder().raw_decode(raw, start)
    except json.JSONDecodeError:
        return match.group(0)
    return raw[start:end]
|
||
|
|
|
||
|
|
|
||
|
|
def validate_schema(data: dict) -> dict:
    """Validate and clean a parsed Gemini response.

    Rules applied:
    - Only keep categories listed in VALID_CATEGORIES (names are stripped
      and title-cased before the check).
    - Only keep non-empty string attributes, stripped of whitespace.
    - Cap at 5 attributes per category.
    - Remove categories left with no attributes.

    Args:
        data: Parsed JSON mapping category name -> list of attribute names.

    Returns:
        A new dict mapping normalized category name -> list of attributes.

    Raises:
        ValueError: If ``data`` is not a dict, or if no valid category
            survives validation.
    """
    # A JSON array or scalar at the top level is a malformed answer; report
    # it as ValueError so it is handled like any other bad response.
    if not isinstance(data, dict):
        raise ValueError(
            f"Expected a JSON object, got {type(data).__name__}"
        )

    max_attributes = 5
    cleaned = {}
    for category, attributes in data.items():
        # Normalize category name
        cat = category.strip().title() if isinstance(category, str) else category

        # Skip invalid categories
        if not isinstance(cat, str) or cat not in VALID_CATEGORIES:
            print(f"Skipping unknown category: {cat}")
            continue

        # A bare string would be iterated character by character and other
        # scalars are not iterable at all, so only accept real sequences.
        if not isinstance(attributes, (list, tuple)):
            print(f"Skipping category with non-list attributes: {cat}")
            continue

        # Only keep non-empty string attributes
        attrs = [a.strip() for a in attributes if isinstance(a, str) and a.strip()]

        # Cap at 5 per category
        attrs = attrs[:max_attributes]

        # Skip empty categories
        if attrs:
            cleaned[cat] = attrs

    if not cleaned:
        raise ValueError("No valid categories found after validation")

    return cleaned
|
||
|
|
|
||
|
|
|
||
|
|
def call_gemini(query: str) -> dict:
    """Ask Gemini for evaluation criteria for ``query`` and validate them.

    Sends one prompt to the Vertex AI ``generateContent`` endpoint and makes
    up to three attempts. An attempt is retried when the service answers
    429 (with exponential backoff) or when the answer cannot be turned into
    a valid category mapping.

    Args:
        query: The decision topic, interpolated verbatim into the prompt.

    Returns:
        The dict produced by ``validate_schema``: category name -> list of
        attribute names.

    Raises:
        RuntimeError: If all three attempts fail (bad payload or rate limit).
        Exception: Whatever ``raise_for_status`` or the HTTP client raises
            for non-429 HTTP errors, timeouts and connection failures; these
            are not retried here.
    """
    prompt = f"""You are a world-class decision analysis expert.

Task: Generate the MOST IMPORTANT and FREQUENTLY USED evaluation criteria for someone making a decision about: "{query}"

Rules:
- Only include criteria that are HIGHLY RELEVANT to "{query}"
- Prioritize criteria that people MOST COMMONLY consider for this topic
- Each attribute must be SPECIFIC and MEASURABLE, not generic
- Order attributes by importance (most important first)
- Use concise names (2-5 words max per attribute)

Output format: Pure JSON only. No markdown. No explanation.
Use ONLY these category keys: Performance, Financial, Risk, Maintenance, Benefits, Time, Requirements, Scalability, Alternatives, Usability, Security, Reliability, Support, Sustainability

Example for "buying a car":
{{"Performance":["Engine Power","Top Speed","0-100 Acceleration","Fuel Efficiency"],"Financial":["Purchase Price","Insurance Cost","Resale Value","Running Cost"],"Maintenance":["Service Interval","Spare Parts Availability","Warranty Period"]}}

Now generate for: "{query}"
Return ONLY the JSON object."""

    url = f"https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/gemini-2.5-flash-lite:generateContent"

    max_attempts = 3
    request_timeout = 60  # seconds; without a timeout a stalled call blocks forever

    for attempt in range(max_attempts):
        try:
            res = http_requests.post(
                url,
                headers={
                    "Authorization": f"Bearer {get_access_token()}",
                    "Content-Type": "application/json",
                },
                json={
                    "contents": [{"role": "user", "parts": [{"text": prompt}]}],
                    "generationConfig": {"temperature": 0.2, "maxOutputTokens": 1024},
                },
                timeout=request_timeout,
            )

            # Handle rate limit with exponential backoff
            if res.status_code == 429:
                if attempt == max_attempts - 1:
                    # No attempt follows, so waiting would only delay the error.
                    break
                wait = 2 ** attempt  # 1s, 2s
                print(f" Rate limited, waiting {wait}s... (attempt {attempt + 1})")
                time.sleep(wait)
                continue

            res.raise_for_status()

            try:
                raw = res.json()["candidates"][0]["content"]["parts"][0]["text"]
            except (KeyError, IndexError, TypeError) as shape_err:
                # A reply without candidates/parts is a bad payload like any
                # other: surface it as ValueError so it gets retried below.
                raise ValueError(
                    f"Unexpected Gemini response shape: {shape_err!r}"
                ) from shape_err

            clean = clean_json_response(raw)
            data = json.loads(clean)
            return validate_schema(data)

        # json.JSONDecodeError is a ValueError subclass, so this covers both.
        except ValueError as e:
            print(f" Attempt {attempt + 1} failed: {e}")
            if attempt == max_attempts - 1:
                raise RuntimeError(
                    f"Gemini failed after {max_attempts} attempts: {e}"
                ) from e

    raise RuntimeError(f"Gemini failed after {max_attempts} attempts: rate limit")
|
||
|
|
|
||
|
|
|
||
|
|
def bootstrap_domain(query: str):
    """Generate criteria for ``query`` and store them with embeddings.

    Steps:
    1. Ask Gemini (via ``call_gemini``) for category -> attributes.
    2. Embed the query and every attribute with the module-level ``model``.
    3. In a single transaction, upsert the domain, its dimension groups and
       their attributes.

    Args:
        query: The decision topic; also used as the domain name.

    Raises:
        RuntimeError: Propagated from ``call_gemini`` when it gives up.
    """
    # call_gemini makes its own three attempts and reports exhaustion as
    # RuntimeError; it never lets ValueError/JSONDecodeError escape, so an
    # additional retry loop around it here could never trigger.
    data = call_gemini(query)
    print("Gemini response validated")

    # Compute all embeddings up front so the database transaction below is
    # not held open while the model runs.
    domain_embedding = str(model.encode(query).tolist())
    embedded_groups = [
        (group, [(attr, str(model.encode(attr).tolist())) for attr in attrs])
        for group, attrs in data.items()
    ]

    with engine.begin() as conn:
        # The no-op DO UPDATE makes RETURNING yield the id for existing rows too.
        domain_id = conn.execute(text("""
            INSERT INTO domains (name, embedding)
            VALUES (:n, CAST(:e AS vector))
            ON CONFLICT (name) DO UPDATE SET name = EXCLUDED.name
            RETURNING id
        """), {"n": query, "e": domain_embedding}).scalar()

        for group, attr_rows in embedded_groups:
            group_id = conn.execute(text("""
                INSERT INTO dimension_groups (domain_id, name)
                VALUES (:d, :g)
                ON CONFLICT (domain_id, name) DO UPDATE SET name = EXCLUDED.name
                RETURNING id
            """), {"d": domain_id, "g": group}).scalar()

            for attr, emb in attr_rows:
                conn.execute(text("""
                    INSERT INTO attributes (group_id, name, embedding)
                    VALUES (:gid, :name, CAST(:emb AS vector))
                    ON CONFLICT (group_id, name) DO NOTHING
                """), {"gid": group_id, "name": attr, "emb": emb})

    print(f"Domain bootstrapped: {query}")
|