import hmac
import os

from flask import Flask, render_template, request, redirect, url_for, session, jsonify

import utils
from auto_prompts import build_prompt, load_schema

app = Flask(__name__)

# Secret used to sign session cookies. The fallback is committed to source: with
# Flask's default cookie-based sessions, anyone who knows it can sign their own
# cookie containing logged_in=True and skip the login form entirely.
app.secret_key = os.environ.get("FLASK_SECRET_KEY", "change_me_in_prod")

# The single account accepted by the login form, overridable via the environment.
USERNAME = os.environ.get("APP_USERNAME", "szellem")
PASSWORD = os.environ.get("APP_PASSWORD", "casper+2026")

# The built-in fallbacks above are kept for backward compatibility, but make it
# visible at startup when a deployment is silently relying on them.
_unset_secret_vars = [
    name
    for name in ("FLASK_SECRET_KEY", "APP_USERNAME", "APP_PASSWORD")
    if name not in os.environ
]
if _unset_secret_vars:
    app.logger.warning(
        "Insecure built-in defaults in use for: %s - set these environment "
        "variables before exposing this app",
        ", ".join(_unset_secret_vars),
    )


@app.route("/", methods=["GET", "POST"])
def login():
    """Show the login form and authenticate submitted credentials.

    Already-authenticated sessions are redirected straight to the app. On POST,
    the ``username``/``password`` form fields are checked against ``USERNAME``
    and ``PASSWORD``. On success the session is marked ``logged_in`` and the
    user is redirected to the app; otherwise the form is re-rendered with an
    error message.
    """
    if session.get("logged_in"):
        return redirect(url_for("index"))

    error = None
    if request.method == "POST":
        username = request.form.get("username")
        password = request.form.get("password")
        # hmac.compare_digest runs in time independent of where the inputs
        # differ, so response timing does not reveal how much of a credential
        # matched. It rejects non-ASCII str, hence the UTF-8 encoding. Both
        # comparisons always run, so timing does not reveal which field was
        # wrong either.
        user_ok = hmac.compare_digest(
            (username or "").encode("utf-8"), USERNAME.encode("utf-8")
        )
        pass_ok = hmac.compare_digest(
            (password or "").encode("utf-8"), PASSWORD.encode("utf-8")
        )
        # A missing form field never authenticates, even if a credential is "".
        if username is None or password is None or not (user_ok and pass_ok):
            error = "Invalid Credentials. Please try again."
        else:
            # Discard any pre-login session data before granting access.
            session.clear()
            session["logged_in"] = True
            return redirect(url_for("index"))

    return render_template("login.html", error=error)


@app.route("/app", methods=["GET"])
def index():
    """Serve the main application page; unauthenticated visitors are sent to login."""
    is_authenticated = session.get("logged_in")
    if is_authenticated:
        return render_template("index.html")
    return redirect(url_for("login"))


def _group_symptoms(fields, symptoms):
    """Bucket selected symptom keys by the ``group_id`` of their schema field.

    Args:
        fields: Schema field mappings. Each has a ``"key"`` and may have a
            ``"group_id"``.
        symptoms: Selected symptom keys (strings).

    Returns:
        A dict mapping each group_id to the list of selected keys in that
        group, in selection order. Keys unknown to the schema, or whose field
        has no ``group_id``, are dropped.
    """
    key_to_group = {field["key"]: field.get("group_id") for field in fields}
    groups = {}
    for key in symptoms:
        gid = key_to_group.get(key)
        if gid is not None:
            groups.setdefault(gid, []).append(key)
    return groups


@app.route("/evaluate", methods=["POST"])
def evaluate():
    """Evaluate ``content`` for the selected symptoms, one LLM call per group.

    Expects a JSON body ``{"content": str, "symptoms": [str, ...]}``. Selected
    symptom keys are grouped by their schema ``group_id``. For each group a
    prompt is built, sent to ``utils.call_llm``, and the reply is parsed with
    ``utils.extract_json_block``.

    Returns:
        On success, JSON of the form
        ``{"ok": True, "results": {...}, "raw": {...}}``, where ``results``
        holds the parsed values whose keys are among the selected symptoms
        and ``raw`` maps each group_id to the unparsed LLM reply.

        On failure, JSON ``{"ok": False, "error": ...}`` with status:
        401 when not logged in; 400 for missing or invalid ``content`` or
        ``symptoms``; 500 if loading the schema, building a prompt, calling
        the LLM, or parsing its reply raises.
    """
    if not session.get("logged_in"):
        return jsonify({"ok": False, "error": "Not logged in"}), 401

    data = request.get_json(silent=True)
    # A non-object JSON body (e.g. a list) previously crashed on .get() here,
    # outside the try below.
    if not isinstance(data, dict):
        data = {}
    raw_content = data.get("content")
    content = raw_content.strip() if isinstance(raw_content, str) else ""
    symptoms = data.get("symptoms")
    app.logger.debug("selected symptoms: %r", symptoms)
    if not content:
        return jsonify({"ok": False, "error": "Missing content"}), 400
    if not isinstance(symptoms, list) or not symptoms:
        return jsonify({"ok": False, "error": "Missing symptoms"}), 400
    # Entries are used as dict/set keys; unhashable values (e.g. nested objects)
    # previously surfaced as an opaque 500.
    if not all(isinstance(s, str) for s in symptoms):
        return jsonify({"ok": False, "error": "Invalid symptoms"}), 400

    # Relative to the process working directory.
    # NOTE(review): build_prompt presumably re-reads this schema itself - confirm.
    schema_path = "static/symptoms_v3.yaml"
    selected = set(symptoms)  # O(1) membership when filtering LLM output

    try:
        schema = load_schema(schema_path)
        groups = _group_symptoms(schema["fields"], symptoms)

        final_results = {}
        raw_outputs = {}

        # One LLM call per symptom group.
        for gid in groups:
            prompt = build_prompt(content, schema_path=schema_path, group_id=gid)
            app.logger.debug("prompt for group %r:\n%s", gid, prompt)
            # Debug dump of the most recent prompt (overwritten per group).
            # An explicit encoding keeps non-ASCII prompts from failing under
            # a non-UTF-8 locale default.
            with open("prompt.txt", "w", encoding="utf-8") as dump:
                dump.write(prompt)

            llm_text = utils.call_llm(prompt)
            raw_outputs[gid] = llm_text

            parsed = utils.extract_json_block(llm_text)
            # Only a JSON object can be filtered by key. Anything else
            # (None, a list, ...) contributes no results, but its raw text
            # is still returned.
            if isinstance(parsed, dict):
                final_results.update(
                    {k: v for k, v in parsed.items() if k in selected}
                )

        return jsonify({
            "ok": True,
            "results": final_results,
            "raw": raw_outputs,
        })

    except Exception as e:
        # Top-level boundary: record the traceback server-side as well.
        app.logger.exception("evaluate failed")
        return jsonify({"ok": False, "error": str(e)}), 500


@app.route("/logout", methods=["POST", "GET"])
def logout():
    """Drop all session state and return the user to the login page."""
    # NOTE(review): accepting GET lets any third-party page trigger a logout
    # (e.g. via an <img> tag); consider POST-only if the frontend allows it.
    session.clear()
    login_url = url_for("login")
    return redirect(login_url)


if __name__ == "__main__":
    # Development server entry point. Werkzeug's interactive debugger allows
    # arbitrary code execution from the browser, so it must never be
    # hard-enabled while listening on all interfaces. Opt in explicitly with
    # FLASK_DEBUG=1 (or true/yes/on).
    _debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("APP_PORT", "8085")),
        debug=_debug,
    )